训练一个数据不够多的数据集是什么体验?

举报
haha_y_c 发表于 2020/07/24 17:33:47 2020/07/24
【摘要】 前言 之前一段时间接触了几位用户提的问题,发现很多人在使用训练的时候,给的数据集寥寥无几,有一些甚至一类只有5张图片。modelarts平台虽然给出了每类5张图片就能训练的限制,但是这种限制对一个工业级的应用场景往往是远远不够的。所以联系了用户希望多增加一些图片,增加几千张图片训练。但是用户后面反馈,标注的工作量实在是太大了。我思忖了一下,分析了一下他应用的场景,做了一些策略变...

前言 

       前一段时间接触了几位用户提的问题,发现很多人在使用训练的时候,给的数据集寥寥无几,有一些甚至一类只有5张图片。modelarts平台虽然给出了每类5张图片就能训练的限制,但是这种限制对一个工业级的应用场景往往是远远不够的。所以联系了用户希望多增加一些图片,增加几千张图片训练。但是用户后面反馈,标注的工作量实在是太大了。我思忖了一下,分析了一下他应用的场景,做了一些策略变化。这里介绍其中一种带标签扩充数据集的方法。

数据集情况

       数据集由于属于用户数据,不能随便展示,这里用一个可以展示的开源数据集来替代。首先,这是一个分类的问题,需要检测出工业零件表面的瑕疵,判断是否为残次品,如下是样例图片:

       这是两块太阳能电板的表面,左侧是正常的,右侧是有残缺和残次现象的,我们需要用一个模型来区分这两类的图片,帮助定位哪些太阳能电板存在问题。左侧的正常样本754张,右侧的残次样本358张,验证集同样,正常样本754张,残次样本357张。总样本在2000张左右,对于一般工业要求的95%以上准确率模型而言属于一个非常小的样本。先直接拿这个数据集用Pytorch加载imagenet的resnet50模型训练了一把,整体精度ACC在86.06%左右,召回率正常类为97.3%,但非正常类为62.9%,还不能达到用户预期。

       当要求用户再多收集,至少扩充到万级的数据集的时候,用户提出,收集数据要经过处理,还要标注,很麻烦,问有没有其他的办法可以节省一些工作量。这可一下难倒了我,数据可是深度学习训练的灵魂,这可咋整啊。

       仔细思考了一阵子,想到modelarts上有智能标注然后人工校验的功能,就让用户先试着体验一下这个功能。我这边拿他给我的数据集想想办法。查了些资料,小样本学习few-shot fewshot learning (FSFSL)的常见方法,基本都是从两个方向入手。一是数据本身,二是从模型训练本身,也就是对图像提取的特征做文章。这里想着从数据本身入手。

       首先观察数据集,都是300*300的灰度图像,而且都已太阳能电板表面的正面俯视为整张图片。这属于预先处理的很好的图片。那么针对这种图片,翻转镜像对图片整体结构影响不大,所以我们首先可以做的就是flip操作,增加数据的多样性。flip效果如下:

       这样数据集就从1100张扩增到了2200张,还是不是很多,但是直接观察数据集已经没什么太好的扩充办法了。这时想到用Modelarts模型评估的功能来评估一下模型对数据的泛化能力。这里调用了提供的SDK:deep_moxing.model_analysis下面的analyse接口。

def validate(val_loader, model, criterion, args):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')
    pred_list = []
    target_list = []
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)
            # 获取logits输出结果pred和实际目标的结果target
            pred_list += output.cpu().numpy()[:, :2].tolist()
            target_list += target.cpu().numpy().tolist()
            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5), i=i)
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)
        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
    # 获取图片的存储路径name
    name_list = val_loader.dataset.samples
    for idx in range(len(name_list)):
        name_list[idx] = name_list[idx][0]
    analyse(task_type='image_classification', save_path='/home/image_labeled/',
            pred_list=pred_list, label_list=target_list, name_list=name_list)
    return top1.avg

       上段代码大部分都是Pytorch训练ImageNet中的验证部分代码,需要获取三个list,模型pred直接结果logits、图片实际类别target和图片存储路径name。然后按如上的调用方法调用analyse接口,会在save_path的目录下生成一个json文件,放到Modelarts训练输出目录里,就能在评估结果里看到对模型的分析结果。我这里是线下生成的json文件再上传到线上看可视化结果。关于敏感度分析结果如下:

       这幅图的意思是,不同的特征值范围图片分别测试的精度是多少。比如亮度敏感度分析的第一项0%-20%,可以理解为,在图片亮度较低的场景下对与0类和其他亮度条件的图片相比,精度要低很多。整体来看,主要是为了检测1类,1类在图片的亮度和清晰度两项上显得都很敏感,也就是模型不能很好地处理图片的这两项特征变化的图片。那这不就是我要扩增数据集的方向吗?

       同时,ModelArts平台还提供了使用数据扩增的接口可以直接扩充数据集:

       好的,那么我就试着直接对全量的数据集做了扩增,得到一个正常类2210张,瑕疵类1174张图片的数据集,用同样的策略扔进pytorch中训练,得到的结果:


方法 Accurancy recall norm类 recall abnorm类
原版
86.06% 97.3% 62.9%
从1100张扩增到2940张 86.31% 97.6% 62.5%

怎么回事,和设想的不太一样啊。。。

       重新分析一下数据集,我突然想到,这种工业类的数据集往往都存在一个样本不均匀的问题,这里虽然接近2:1,但是检测的要求针对有瑕疵的类别的比较高,应该让模型倾向于有瑕疵类去学习,而且看到1类的也就是有瑕疵类的结果比较敏感,所以其实还是存在样本不均衡的情况。由此后面的这两种增强方法只针对了1类也就是有问题的破损类做,最终得到3000张左右,1508张正常类图片,1432张有瑕疵类图片,这样样本就相对平衡了。用同样的策略扔进resnet50中训练。最终得到的精度信息:

方法 Accurancy recall norm类 recall abnorm类
原版
86.06% 97.3% 62.9%
从1100张扩增到2940张 89.13% 97.2% 71.3%

       可以看到,同样在验证集,正常样本754张,残次样本357张的样本上,Acc1的精度整体提升了接近3%,重要指标残次类的recall提升了8.4%!嗯,很不错。所以直接扩充数据集的方法很有效,而且结合模型评估能让我参考哪些扩增的方法是有意义的。当然还有很重要的一点,要排除原始数据集存在的问题,比如这里存在的样本不均衡问题,具体情况具体分析,这个扩增的方法就会变得简单实用。

       之后基于这个实验的结果和数据集。给帮助用户改了一些训练策略,换了个更厉害的网络,就达到了用户的要求,当然这都是定制化分析的结果,这里不详细展开说明了,或者会在以后的博客中更新。



ModelArts数据处理相关博客:

1.数据处理简介:https://bbs.huaweicloud.com/blogs/193413
2.数据生成域迁移:https://bbs.huaweicloud.com/blogs/193405数据风格变换:ModelArts的数据域迁移功能
3.数据校验:https://bbs.huaweicloud.com/blogs/193412 数据校验--给你的数据做个体检吧
4.数据去重:https://bbs.huaweicloud.com/blogs/193420数据去重---ModelArts在数据处理上的应用技巧-免费,欢迎大家体验
5.数据清洗:https://bbs.huaweicloud.com/blogs/193421数据清洗---ModelArts在数据处理上的应用技巧-免费,欢迎大家体验
6.难例筛选:https://bbs.huaweicloud.com/blogs/193422如何加速AI模型迭代:Modelarts的难例筛选功能




引用数据集来自:


Buerhop-Lutz, C.; Deitsch, S.; Maier, A.; Gallwitz, F.; Berger, S.; Doll, B.; Hauch, J.; Camus, C. & Brabec, C. J. A Benchmark for Visual Identification of Defective Solar Cells in Electroluminescence Imagery. European PV Solar Energy Conference and Exhibition (EU PVSEC), 2018. DOI: 10.4229/35thEUPVSEC20182018-5CV.3.15

Deitsch, S.; Buerhop-Lutz, C.; Maier, A. K.; Gallwitz, F. & Riess, C. Segmentation of Photovoltaic Module Cells in Electroluminescence Images. CoRR, 2018, abs/1806.06530

Deitsch, S.; Christlein, V.; Berger, S.; Buerhop-Lutz, C.; Maier, A.; Gallwitz, F. & Riess, C. Automatic classification of defective photovoltaic module cells in electroluminescence images. Solar Energy, Elsevier BV, 2019, 185, 455-468. DOI: 10.1016/j.solener.2019.02.067

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

举报
请填写举报理由
0/200