深度学习竞赛中常见的一种手段:测试时增强(TTA)
前言
在很多深度学习的比赛项目中,各种方法trick层出不穷,其中有一种颇受争议的方法就是在测试时使用增强的手段,将输入的源图片生成多份分别送入模型,然后对所有的推理结果做一个综合整合。这种方法被称为测试时增强(test time augmentation, TTA),今天我们就来说说这个测试时增强。
TTA流程
TTA的基本流程是通过对原图做增强操作,获得很多份增强后的样本与原图组成一个数据组,然后用这些样本获取推理结果,最后把多份的推理结果按一定方法合成得到最后的推理结果再进行精度指标计算。流程图如下:
这么看上去需要确认很多问题:
原图片需要用什么增强方法来生成新的样本。
生成的样本在获取推理结果之后应该使用什么样的方法进行合成。
我们举个简单的例子来说明TTA的作用以及如何利用ModelArts平台提供的功能来使用TTA。
TTA使用实例
实验环境:
数据集:在前一篇博客中,我有说明一种用于解决数据过少和不均衡现象的方法,同样我们也使用这个太阳能电板的数据集。数据集样例图片如下:
其中左侧为正常样本图像,共754张,右侧为有瑕疵的电板图像,共358张,经过一定的增强手段后扩充至1508张正常类图片,1432张有瑕疵类图片,关于扩充的方法和过程详见我的上一篇博客https://bbs.huaweicloud.com/blogs/189148。
使用框架及算法:pytorch官方提供的训练imagenet开源代码,参考https://github.com/pytorch/examples/tree/master/imagenet
训练策略:50个epoch,初始学习率lr0.001,batchsize16用Adam的优化器训练。
原模型精度信息:
精度信息 | 正常类 | 有瑕疵非正常类 |
召回率recall | 97.2% | 71.3% |
精度值accuracy | 89.13% |
TTA过程:
首先,我们要选定使用的增强方法来获取多样本。这里有两种方法:
(1).从训练中使用的增强手段入手,用训练中使用的增强手段获取多样本。
如pytorch训练imagenet的代码中,使用了算子transforms.RandomHorizontalFlip()做垂直方向的翻转操作。那么对于模型而言,应该也见过很多经过垂直翻转的图片,所以我们可以用垂直方向的翻转来作为增强手段的一种。
(2).进行模型评估,从模型评估的结果中分析该使用什么样的增强方法。
对原模型进行评估,评估代码如下,这里是修改了开源代码中validate部分做前向部分推理的代码:
with torch.no_grad(): end = time.time() for i, (images, target) in enumerate(val_loader): if args.gpu is not None: images = images.cuda(args.gpu, non_blocking=True) target = target.cuda(args.gpu, non_blocking=True) # compute output output_origin = model(images) output = output_origin loss = criterion(output, target) pred_list += output.cpu().numpy()[:, :2].tolist() target_list += target.cpu().numpy().tolist() # measure accuracy and record loss acc1, acc5 = accuracy(output, target, topk=(1, 5), i=i) losses.update(loss.item(), images.size(0)) top1.update(acc1[0], images.size(0)) top5.update(acc5[0], images.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) # TODO: this should also be done with the ProgressMeter print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' .format(top1=top1, top5=top5)) name_list = val_loader.dataset.samples for idx in range(len(name_list)): name_list[idx] = name_list[idx][0] analyse(task_type='image_classification', save_path='./', pred_list=pred_list, label_list=target_list, name_list=name_list)
评估就是需要获得三个list,推理的直接的结果logits组合成的pred_list,存储的是每一张图片直接的预测结果,如[[8.725419998168945, 21.92235565185547]...[xxx, xxx]]。一个真实的label值组成的target_list,存储的是每一张图片的真实标签,如[0, 1, 0, 1, 1..., 1, 0]。还有原图像文件存储的路径组合成的name_list,如[xxx.jpg, ... xxx.jpg],这里是从pytorch度数据模块的类中通过val_loader.dataset.samples获取到后重新组合的。然后调用deep_moxing库中的analyse接口,在save_path下会生成一个model_analysis_results.json的文件,将这个文件上传到页面上任意一个训练任务的输出目录下,就能在页面的评估界面上看到对模型评估的结果:
这结果中我们需要分析模型的敏感度:
图中能看到,0类(正常类)随着图像清晰度的增大F1-score会提升,也就是说,模型在清晰的图片上,对正常类的检测表现更好,而在1类(瑕疵类)随着图像清晰读增大精度会下降,说明对模型而言,模糊的图片能让它检测有瑕疵类更加准确。由于该模型侧重于对瑕疵类的鉴别,所以可以使用图像模糊的手段作为TTA的增强方法。
2. 接下来我们可以看看在pytorch中,如何加入TTA。
pytorch的好处在于,可以直接获取到输入模型前的tensor并进行想要的操作。如在eval中
with torch.no_grad(): end = time.time() for i, (images, target) in enumerate(val_loader): if args.gpu is not None: images = images.cuda(args.gpu, non_blocking=True)
这里拿到的images就是已经做好前处理的一个batch的图片数据。由步骤1中确定了两种增强方法,竖直方向的翻转和模糊。
pytorch中的翻转,在版本大于0.4.0时,可以用:
def flip(x, dim): indices = [slice(None)] * x.dim() indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) return x[tuple(indices)]
dim为模式,这里使用2为竖直方向的翻转,3为水平方向,1为做通道翻转。使用img_flip = flip(images, 2)就能得到竖直方向翻转的图片。
模糊稍多一些操作,可以利用cv2中自带的blur操作:
img = images.numpy() img[0] = cv2.blur(img[0], (3, 3)) images_blur = torch.from_numpy(img.copy())
3. 结果合成
我们现在就得到了三个输出,原图的推理结果origin_result,竖直方向翻转后得到的结果flip_output,模糊后得到的blur_output。那么该如何合成呢?
先看flip_output,一个想法是,原训练中见过的做过翻转的图片所占的比例是多少,在最终的输出一张做过翻转的图片对结果的贡献权重就是多少。那么相信很多有深度学习经验的同学们知道,一般模型做FLIP的概率为0.5,也就是模型见过的做过翻转的图片,大致比例上为0.5,那么flip的结果最最终结果的贡献就也是0.5,可得:
logits = 0.5*origin_result + 0.5*flip_result
此时,模型的精度结果为:
操作 | acc | norm类recall | abnorm类recall |
原版 | 89.13% | 97.2% | 71.3% |
flip结果合成 | 87.74% |
93.7% | 72.7% |
可以看到,虽然损失了norm类的精度,但是相对而言更重要的指标abnorm类的recall有提升。
然后分析blur_output,可以看到,位于最低的0-20%时,瑕疵类的精度是最高的,但是norm类的精度掉的太多,而且模糊本身就是提升abnorm类精度的,所以我们做一个折中,同样取blur图片的贡献值为0.5,可得公式:
logit = 0.5*origin_result + 0.5*blur_output
此时,模型的精度结果为:
操作 | acc | norm类recall | abnorm类recall |
原版 | 89.13% | 97.2% | 71.3% |
blur结果合成 | 88.117% |
94.8% | 73.3% |
可以看到,norm类的精度下降较多,abnorm类增长明显,与模型评估的分析结果一致。
综上,我们调整的结果虽然对norm类的损失较多导致整体精度下降,但是这是符合模型分析的结果的,我们需要的指标就是abnorm类recall的提升,而且可以看到,模型评估的结果要稍好于使用原版增强的合成结果。
小结
我们这里实验了两种使用test time augmatation的方法,一种是根据训练过程自带的增强方法来选择测试前增强,另一种是通过对模型进行敏感度分析,分析图片什么样的特征范围对于模型的判别最有帮助。当然,这里很重要的一点:TTA会增加模型推理的时间,对推理时延要求很高的人工智鞥你算法应用请仔细抉择选择合适的解决方法。
- 点赞
- 收藏
- 关注作者
评论(0)