ModelArts域适应算法EfficientMixGVB

举报
T_c_D 发表于 2020/08/31 17:34:18 2020/08/31
【摘要】 本文介绍ModelArts在图像分类领域中改善目标域与源域之间数据分布差异的算法:EfficientMixGVB ,该算法在多个公开数据集上超越了现有的域适应算法。


域适应任务简单介绍

为什么我用在重庆收集的汽车道路数据集训练得到的车辆检测模型,在杭州道路上识别效果变得很差?

为什么我用真实图片训练得到的人物识别模型,无法识别动漫人物图像?

为什么我的算法训练效果这么好,部署成推理服务之后,实际使用效果变差了?

在迁移学习中, 当源域和目标的数据分布不同 ,但两个算法任务相同,这种特殊的迁移学习叫做域适应 (Domain Adaptation,DA )。几乎所有算法在落地实际场景时都会遇到域适应问题。因为带有标注的训练数据集是很容易获得的,我们训练模型使用的都是这些数据,但是将模型应用到实际场景中的数据来源往往是不同的且没有标注的。推理时的数据域与训练时的数据域分布差异很大,就可能导致模型推理效果变差。

以车辆检测模型为例,你在重庆收集汽车道路数据并精确标注后,训练了一个车辆检测模型,在重庆路上测试效果极好,但在杭州道路上就变得非常糟糕。重庆的道路高高低低,杭州道理相对平坦;重庆的出租车是黄色的,杭州的出租车是绿色的等,同样动漫人物中的人形一般较为夸张、抽象、色彩鲜艳,与真实人形差别较大,你训练的模型在这些场景下表现很差的原因是数据域发生了变化。

那怎么解决这类训练数据域与测试数据域变化很大的问题呢?在视觉领域中已经提出了许多域适应方法来减少训练数据域与测试数据域之间的差异。几乎所有域适应方法都是让模型同时训练有标注的训练数据和无标注的测试数据,包含了从模型结构出发改善两个域之间差异的方法,也有生成具有测试数据域style的训练数据的方法。

在改善数据域差异方面ModelArts已经推出了:无监督数据域迁移算法

本文介绍ModelArts在图像分类领域中模型结构方面改善两个域之间差异的算法:EfficientMixGVB ,该算法在多个公开数据集上超越了现有的域适应算法。


ModelArts域适应算法:EfficientMixGVB

域适应算法流程

在域适应算法中

  • 您需要准备两个数据集,一个是有标签的源域训练集,一个是无标签的目标域训练集

  • 域适应模型接受到两个数据集,使用域适应算法结构优化分类模型;

  • 训练完成后,得到已经适应目标域数据的分类模型,可用于ModelArts在线推理;

算法使用

按照数据集创建文档,创建源域数据集和目标域数据集,其中源域数据集是有标注的,目标域是无标注数据集(即使有标注也忽略,如果部分图像有标注,建议导入源域数据集中,提高模型精度)。

在AI市场中选择EfficientMixGVB算法并订阅

创建EfficientMixGVB算法作业:选择算法管理中刚订阅的算法,点击创建训练作业

在创建训练作业界面选择源域目标域数据集,以及其他参数即可启动训练作业。

实验结果

在公开数据集office-31上的实验

包含了31类的数据,全部是Office的数据,数据来源为A(Amazon), W(Webcam) 和D(DSLR),


在几乎所有以office-31为实验数据的域适应相关论文中,目标域数据集既用于无监督训练也用于最后的模型测试,这样无法严谨地保证模型具有足够好的泛化性 。于是在本次实验中,我们将目标域数据集1:1随机切分为目标域数据集和测试集,目标域数据集用于无监督训练,测试集用于最终的精度测试。

域适应算法 A->W_test A->D_test D->A_test W->A_test D->W_test W->D_test Average
Source only 76.1 81.12 62.5 60.7 96.7 99.3 76.1
CAN 94.5 95 78 77 99.1 99.8 90.6
SYM 90.8 93.9 74.6 72.5 98.8 100 88.4
BNM 91.5 90.3 70.9 98.5 98.5 100 87.1
GVB-GD 92.59±0.879 90.86±1.123 73.76±0.783 71.85±0.365 98.52±0.125 99.97±0.032 87.9
ours(高性能) 93.11±0.942 92.64±0.718 73.88±0.692 72.21±0.563 98.53±0.154 100.0±0.0 88.3
ours(高精度) 96.66±0.325 95.87±0.393 79.46±0.225 77.18±0.255 98.74±0.03 99.65±0.06 91.6


在公开数据集visda2019上的实验

visda-2019图像分类-多源域适应竞赛 包含了6个域的数据集:clipart(剪贴画图像的集合)、infograph(具有特定对象的信息图图像)、painting(以绘画形式对物体的艺术描绘)、quickdraw(快速绘画)、real(照片和真实世界图像)、sketch(特定对象的草图),均包含了344个相同的类。下图为图像示例,每行代表了一个域,每列代表一个类别。


本次实验使用sketch(特定对象的草图)作为训练集(源域),将clipart(剪贴画图像的集合)中一半数据用于无标签目标域的训练,一半作为测试集。最终的测试结果如下:

域适应算法 sketch->clipart_test
Source only 44.46
CAN 45.91
SYM 56.45
CycleGAN 46.3
GVB-GD 59.26
ours 60.08
ours 62.33


在实际业务数据集lego上的实验

lego数据集来自一个机械臂项目,机械臂会判断当前桌上的lego属于哪个类别(4孔蓝色,2、3、4孔红色、3、4孔绿色,3、4孔橙色等),源域lego_1是实验室训练时自行采集的数据集,目标域lego_2是项目实际展示时场馆采集的数据集。

本次实验使用源域lego_1作为训练集,将目标域lego_2即用作无标签目标域的训练也作为测试集。最终的测试结果如下:

域适应算法 lego_1->lego_2
Source only 33.4
CAN 72.14
SYM 85.21
BNM 58.69
CycleGAN 66.5
GVB-GD 84.51±7.035
ours(高性能) 88.95±5.228
ours(高精度) 92.46±0.824

结合公开数据集visda2019上的实验,CAN算法只在office-31上表现较好,但是在其他数据集上表现均比较差,泛化性能存在问题。我们的EfficientMixGVB算法不管是高性能还是高精度变种,泛化性能均表现良好,精度相比其他算法优势也很大。

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。