小样本学习总结(一)

举报
Deep_Rookie 发表于 2020/10/22 20:25:22 2020/10/22
【摘要】 这里有四张图片,两张是犰狳、两张是穿山甲,在给定一张query图片,我们很快就可以分辨出这张图片所属的类别。我们仅用了四张图片就完成了一个学习任务。但是我们不可能仅通过这四张图片完成一个深度神经网络的训练。这其实就是做few shot learning 的motivation。 传统的监督学习 通过训练集训练模型,推理阶段 待分类的样本虽然之前没有见过,但是其来源于已知的类,其类别包含在训练...

这里有四张图片,两张是犰狳、两张是穿山甲,在给定一张query图片,我们很快就可以分辨出这张图片所属的类别。我们仅用了四张图片就完成了一个学习任务。但是我们不可能仅通过这四张图片完成一个深度神经网络的训练。这其实就是做few shot learning 的motivation。

 image.png

传统的监督学习 通过训练集训练模型,推理阶段 待分类的样本虽然之前没有见过,但是其来源于已知的类,其类别包含在训练集中

 image.png

Few shot learning的目标不是为了让机器识别训练集里的图片 并泛化到测试集。其目标是为了让机器学会学习,aka learn to learn。其学习的目标是为了让模型理解事物的异同,学会区分不同的事物,而不是区分某个指定类别的能力。

小样本学习在推理阶段,其query样本来自于未知的类别,其类别并不包含在训练集中,需要使用小样本构建新的support set,对query样本进行分类

image.png

目前主流的小样本学习的主要方法是基于元学习的框架

在第一个阶段,首先是通过包含较多样本的base set,构建若干task 或者 称为episode ,这些episode包含support set和query set ,support set包含N类,这N个类从base set中随机选择。每类包含K个样本。因此Few shot learning 或者 meta learning 也称为 N WAY K SHOT 问题。

通过支撑集和查询集构建的episode,学习这个分类模型,通过多个episode反复训练模型。在testing 阶段使用小样本构建新的support set,得到新的分类模型,通过该模型对query样本进行分类

image.png



Match Network


给定support set S 以及待测试样本 xhat,其预期输出是一个加权求和的形式

其中a函数表示如下,其实质是对余弦距离的softmax,a中对于支撑样本和查询样本的Embedding函数是不同的,通过C()函数来计算两个Embedding的余弦距离

支撑样本的Embedding是g,是基于双向LSTM来学习的,每个支撑样本的Embedding是其他支撑集是相关的

测试样本的Embedding是一个基于注意力机制的LSTM,其中的f‘是一个encoder,提取图像的特征,可以是任意的一种神经网络,每个测试样本的Embedding和支撑集是有关的。

文中称为 Full Context Embedding,意思就是说,在计算 embedding 的时候,也要用到 S,而不是简单将数据输入到一个网络模块中然后得到一个表示向量。

在预测时候也是支持集中的样本通过G,测试样本通过F,他们通过转换之后 预测函数求得最终的预测值。

image.png


Prototypical networks

原型网络的IDEA:每个类别都存在其原型表达,该表达就是支撑集在Embedding空间的均值(也就是说会存在一个Embedding space,在这个space中每一类的样本都趋向于比较靠近,这类样本的均值就被认为是其原型表达)。在这种思想的指导下,问题就由原来的分类问题变成了在Embedding空间中的最近邻问题。当输入一个待测样本的时候,通过计算该样本和每一类原型表达的欧式距离即可得到其类别。


Relation network

Idea :之前的工作都将中心放在了学习一个好的数据表达上,即怎样学习一个好的Embedding的网络,然后使用已有的度量方法来进行分类,这个度量方法可能就是根据一些任务或者先验知识来确定的。那么如果距离度量也可以通过训练得到,就可以得到泛化能力更好的模型。

因此该方法除了设计一个Embedding外,还设计了一个打分的网络,就是来对样本之间,他们是不是属于一类来计算一个打分,来度量他们之间的相似程度。

其实可以看到Relation network跟match network 的结构很像,不过对于support和query使用的是同一个embedding。通过这个embedding之后,黄色的表示测试样本,前面的块是支撑集中的样本embedding后的特征映射。relation 模块用一个拼接函数来将 support sample 和 query sample 的 feature map 拼接起来,然后度量网络通过一个g函数对这些特征映射进行融合并计算匹配分数,在转化为one hot的形式 通过监督学习来同时学习embedding函数F和这个打分网络G

image.png

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。