ModelArts域适应算法EfficientMixGVB
域适应任务简单介绍
为什么我用在重庆收集的汽车道路数据集训练得到的车辆检测模型,在杭州道路上识别效果变得很差?
为什么我用真实图片训练得到的人物识别模型,无法识别动漫人物图像?
为什么我的算法训练效果这么好,部署成推理服务之后,实际使用效果变差了?
在迁移学习中, 当源域和目标的数据分布不同 ,但两个算法任务相同,这种特殊的迁移学习叫做域适应
以车辆检测模型为例,你在重庆收集汽车道路数据并精确标注后,训练了一个车辆检测模型,在重庆路上测试效果极好,但在杭州道路上就变得非常糟糕。重庆的道路高高低低,杭州道理相对平坦;重庆的出租车是黄色的,杭州的出租车是绿色的等,同样动漫人物中的人形一般较为夸张、抽象、色彩鲜艳,与真实人形差别较大,你训练的模型在这些场景下表现很差的原因是数据域发生了变化。
那怎么解决这类训练数据域与测试数据域变化很大的问题呢?在视觉领域中已经提出了许多域适应方法来减少训练数据域与测试数据域之间的差异。几乎所有域适应方法都是让模型同时训练有标注的训练数据和无标注的测试数据,包含了从模型结构出发改善两个域之间差异的方法,也有生成具有测试数据域style的训练数据的方法。
在改善数据域差异方面ModelArts已经推出了:无监督数据域迁移算法。
本文介绍ModelArts在图像分类领域中模型结构方面改善两个域之间差异的算法:EfficientMixGVB ,该算法在多个公开数据集上超越了现有的域适应算法。
ModelArts域适应算法:EfficientMixGVB
域适应算法流程
在域适应算法中
您需要准备两个数据集,一个是有标签的源域训练集,一个是无标签的目标域训练集 ;
域适应模型接受到两个数据集,使用域适应算法结构优化分类模型;
训练完成后,得到已经适应目标域数据的分类模型,可用于ModelArts在线推理;
算法使用
按照数据集创建文档,创建源域数据集和目标域数据集,其中源域数据集是有标注的,目标域是无标注数据集(即使有标注也忽略,如果部分图像有标注,建议导入源域数据集中,提高模型精度)。
在AI市场中选择EfficientMixGVB算法并订阅
创建EfficientMixGVB算法作业:选择算法管理中刚订阅的算法,点击创建训练作业
在创建训练作业界面选择源域目标域数据集,以及其他参数即可启动训练作业。
实验结果
在公开数据集office-31上的实验
包含了31类的数据,全部是Office的数据,数据来源为A(Amazon), W(Webcam) 和D(DSLR),
在几乎所有以office-31为实验数据的域适应相关论文中,目标域数据集既用于无监督训练也用于最后的模型测试,这样无法严谨地保证模型具有足够好的泛化性 。于是在本次实验中,我们将目标域数据集1:1随机切分为目标域数据集和测试集,目标域数据集用于无监督训练,测试集用于最终的精度测试。
域适应算法 | A->W_test | A->D_test | D->A_test | W->A_test | D->W_test | W->D_test | Average |
---|---|---|---|---|---|---|---|
Source only | 76.1 | 81.12 | 62.5 | 60.7 | 96.7 | 99.3 | 76.1 |
CAN | 94.5 | 95 | 78 | 77 | 99.1 | 99.8 | 90.6 |
SYM | 90.8 | 93.9 | 74.6 | 72.5 | 98.8 | 100 | 88.4 |
BNM | 91.5 | 90.3 | 70.9 | 98.5 | 98.5 | 100 | 87.1 |
GVB-GD | 92.59±0.879 | 90.86±1.123 | 73.76±0.783 | 71.85±0.365 | 98.52±0.125 | 99.97±0.032 | 87.9 |
ours(高性能) | 93.11±0.942 | 92.64±0.718 | 73.88±0.692 | 72.21±0.563 | 98.53±0.154 | 100.0±0.0 | 88.3 |
ours(高精度) | 96.66±0.325 | 95.87±0.393 | 79.46±0.225 | 77.18±0.255 | 98.74±0.03 | 99.65±0.06 | 91.6 |
在公开数据集visda2019上的实验
visda-2019图像分类-多源域适应竞赛 包含了6个域的数据集:clipart(剪贴画图像的集合)、infograph(具有特定对象的信息图图像)、painting(以绘画形式对物体的艺术描绘)、quickdraw(快速绘画)、real(照片和真实世界图像)、sketch(特定对象的草图),均包含了344个相同的类。下图为图像示例,每行代表了一个域,每列代表一个类别。
本次实验使用sketch(特定对象的草图)作为训练集(源域),将clipart(剪贴画图像的集合)中一半数据用于无标签目标域的训练,一半作为测试集。最终的测试结果如下:
域适应算法 | sketch->clipart_test |
---|---|
Source only | 44.46 |
CAN | 45.91 |
SYM | 56.45 |
CycleGAN | 46.3 |
GVB-GD | 59.26 |
ours | 60.08 |
ours | 62.33 |
在实际业务数据集lego上的实验
lego数据集来自一个机械臂项目,机械臂会判断当前桌上的lego属于哪个类别(4孔蓝色,2、3、4孔红色、3、4孔绿色,3、4孔橙色等),源域lego_1是实验室训练时自行采集的数据集,目标域lego_2是项目实际展示时场馆采集的数据集。
本次实验使用源域lego_1作为训练集,将目标域lego_2即用作无标签目标域的训练也作为测试集。最终的测试结果如下:
域适应算法 | lego_1->lego_2 |
---|---|
Source only | 33.4 |
CAN | 72.14 |
SYM | 85.21 |
BNM | 58.69 |
CycleGAN | 66.5 |
GVB-GD | 84.51±7.035 |
ours(高性能) | 88.95±5.228 |
ours(高精度) | 92.46±0.824 |
结合公开数据集visda2019上的实验,CAN算法只在office-31上表现较好,但是在其他数据集上表现均比较差,泛化性能存在问题。我们的EfficientMixGVB算法不管是高性能还是高精度变种,泛化性能均表现良好,精度相比其他算法优势也很大。
- 点赞
- 收藏
- 关注作者
评论(0)