联邦学习专栏丨数据非独立同分布挑战及方案

举报
就挺突然 发表于 2021/01/05 11:05:40 2021/01/05
【摘要】 作者:李新春声明:本文首发于华为NAIE《网络人工智能园地》微信公众号,如有转载,请注明出处。01NonllD挑战概述在传统机器学习中,数据分布在同一个机器上,并且假设数据是从同一个分布中独立地采样的,即数据独立同分布(Independently Identically Distribution, IID)。后来随着数据量的急剧暴长,大数据时代需要分布式算法进行并行计算,此时分布式优化(Di...

作者:新春

声明:本文首发于华为NAIE《网络人工智能园地》微信公众号,如有转载,请注明出处。

01


NonllD挑战概述


图片

在传统机器学习中,数据分布在同一个机器上,并且假设数据是从同一个分布中独立地采样的,即数据独立同分布(Independently Identically Distribution, IID)。

后来随着数据量的急剧暴长,大数据时代需要分布式算法进行并行计算,此时分布式优化(Distributed Optimization)主要做的事情是利用多台机器对大数据或者大模型进行并行计算,包括模型并行(Model Parallelism)和数据并行(Data Parallelism)两种基本范式。但此时仍只是为了加速计算,各个客户端上的数据仍然是IID的。

但是随着物联网(Internet of Things, IoT)设备的大规模使用,如何协同数以万计的设备及其私有数据进行训练模型则是联邦学习做的事情,此时由于设备归属于某个用户、企业、场景,因此其数据分布往往是差异极其大的,即数据非独立同分布(Non-IID)。第一,非同分布很容易理解,就是因为数据分布差异大;其次,由于受到用户群体、地域关联等因素,这些设备的数据分布往往又是有关联的,即非独立。


02


FedShare-WeightDivergence-Zhao-CoRR2018

图片

图片

本节介绍一篇研究工作《Federated Learning with Non-IID Data》,改篇文章是较早对联邦学习中NonIID进行分析的文章,主要包括以下几个方面的内容:


1/经典联邦学习算法FedAvg在NonIID场景下性能会很差;

2/定义Weight Divergence,并指明它是导致NonIID下FedAvg的性能下3/降主要原因;

4/理论推导Weight Divergence,分析了和其相关的因素(比如,学习率等);

5/提出了一种共享部分公有数据的方法,可以提升NonIID下FedAvg的性能。



首先,论文给出NonIID数据的构造方法,文章在Mnist、Cifar10和KWS数据集。Mnist是手写数字集合分类,有10类,分别为数字0-9;Cifar10是10分类的图像识别任务;KWS是Keyword Spotting任务,即从语音数据中识别出关键词,本质上也可以理解为分类任务,有10个关键词,即10类。总的来说,这三个数据集都是10分类数据。

假设训练数据有 图片 个,其中 图片 是第i个类的样本数量。构造IID的联邦学习设置时,将所有数据均匀地分布在K个客户端,每个客户端上都有C个类别的数据,且第k个客户端上拥有 图片 个第i类样本,这样就保证每个客户端上的类别分布都相同;构造NonIID设置时,根据NonIID程度的不同,分为1-Class NonIID和2-Class NonIID,前者指的是每个客户端上仅有一个类别的数据,后者指的是每个客户端上有且仅有两个类的数据。当 图片 时,1-Class NonIID的设置就是将数据按照类别分配到各个客户端,2-Class NonIID指的是将原始数据分成20份,每份包含某个类别所有数据的一半,然后将这20份按照 图片 类别组合每2份放到一个客户端,其中 图片 。

可以理解为,这里的NonIID设置主要是从同一个数据集中根据样本类别进行划分,客户端之间主要是 图片 发生了变化,而 图片 一致,这种NonIID被称为标记分布漂移(Label Distribution Skew)。不同程度的NonIID主要取决于每个客户端所见到类别的数量,每个客户端上面只有一个类别的数据自然是最NonIID的情形,每个客户端上有所有类别的数据就是IID的设置。

在介绍实验结果之前,简单回顾一下FedAvg的流程。

1

客户端局部训练:客户端从服务器下载模型;客户端在自己的本地数据上更新E轮,即训练E遍数据集;客户端将更新后的模型发送给服务器;

2

服务器全局聚合:服务器接收来自客户端的模型,采用简单加权平均的方法聚合模型,权重根据客户端上样本数目决定;

3

迭代:以上两个过程迭代Max Round次。


这个里面有一些超参数,第一个是客户端局部训练和服务器全局聚合迭代的次数,也称为通信次数(Communication Rounds),第二个是局部训练的轮数E,第三个是局部训练使用的B,即Batch Size。

一般来说,越NonIID的数据,使用的局部训练轮数E越大,那个客户端件模型的差异会越大,因此最后聚合的模型性能就不会太好。

下面的图是实验结果:

图片

其中三幅图分别是三个数据集上的结果。


1、B=1000 SGD指的是所有数据放在一起使用Batch Size=100 * 10进行训练,当做一个性能上界;


2、B=100 E=1 IID指的是IID情况下,使用Batch Size=100,局部迭代1轮的性能,可以看到这个曲线基本上和B=1000 SGD差不多,只是在Cifar10上性能略低一点;


4、B=100 E=1 NonIID(2)指的是2-Class NonIID情况下,Batch Size=100,局部迭代一轮的性能,可以看到会比IID情形差,但是会比下面的NonIID(1)高;


5、B=100 E=1,5 NonIID(1)指的是1-Class NonIID情况下,Batch Size=100,局部迭代1轮或者5轮的时候的模型性能,从这个图中可以看出E=1或者5对性能影响不大,但是会比IID和NonIID(2)低很多。但是此时,Mnist上仍然可以达到90%+的准确率,Cifar10上达到40%+的准确率(也没有像随机猜测那么差)。



为什么说NonIID情况下聚合的模型性能会差呢?下面的示意图展示了联邦学习任务中梯度更新的一些过程:

图片

左边图是IID情形,右边图是NonIID情形。这里面主要比较了使用SGD在所有数据上更新以及使用FedAvg更新的过程。其中 图片 指的是使用SGD进行数据中心化(Data Centralized)更新的权重变化,也就是在所有数据上进行更新; 图片 指的是客户端1-K上的权重变化; 图片 是图片聚合之后的结果,也就是FedAvg中服务器全局聚合之后的权重结果。

可以看出,在IID设置下,一开始 图片 是在第m-1个全局Round之后全局模型聚合的结果,然后下发到各个客户端,各个客户端进行根据样本更新,由于SGD是随机梯度下降,由于采样Batch数据的不确定性,各个客户端的更新并不是完全同向的,更新了T步之后,客户端1上得到了权重 图片 (亮蓝色),客户端K上得到了 图片 (橙色),最后服务器上聚合的结果是图片,可以看到和图片离得比较近。然而在NonIID情形里面,各个客户端上权重的更新方向差异很大,导致最后FedAvg聚合的结果是图片,和SGD的权重图片离得比较远。

FedAvg聚合得到的权重和SGD更新得到的权重的差异被定义为Weight Divergence,该项越小则说明FedAvg聚合的结果和SGD在所有数据上更新的结果接近。具体定义为:

图片


然后文章就比较了IID和NonIID情形下神经网络每层的Weight Divergence,神经网络的训练使用相同的初始化:

图片

从上图可以看出,2-Class NonIID情形下的Weight Divergence (绿色)会比1-Class NonIID和IID情形大很多。

后来文章利用理论推导证明第m次通信后的Weight Divergence和很多项因素有关,比如第m-1次的Weight Divergence,学习率,概率分布距离 图片 等。这个部分放到最后的Theory中介绍。

总的来说,Weight Divergence可以根据上述距离进行简单的近似。上面的距离是一种简化版的EMD,用来衡量两个分布的距离:

图片

首先,文章假设 图片 是均匀分布的且 图片 ,那么为了生成每个客户端具有相同EMD的分布就很容易。首先,给定EMD,从C-Simplex中采样一个分布 图片 ,满足 图片 和和均匀分布的EMD距离为给定的EMD,这样的 图片 有很多,采样多次,每次对 图片 的元素进行Shift可以依次得到 图片 ,前提是 图片 ,然后每个客户端上根据相应的分布采样数据即可构造一个具有相应EMD的NonIID场景。

图片

这边展示的是对于固定的EMD(横轴),多次采样 图片 构造NonIID设置得到的Weight Divergence的平均值和方差,可以看出EMD的大小和Weight Divergence的相关性很强。

最后,文章提出了一种共享数据的方法对NonIID下FedAvg进行改进,记为FedShare。FedShare的做法是留出一个数据集G,该数据集是均匀分布的,其大小 图片 和所有数据的总数 图片 的权比例为图片 。首先,该数据集被用来预训练全局模型,然后下发的时候也会下发一部分全局数据,数目为 图片 ,到各个客户端,客户端同时在这部分全局数据和私有数据上训练。

FedShare的示意图和其性能图如下, 图片 为NonIID下FedAvg的性能,但是随着 图片 增大,性能逐渐变好。在共享所有数据5%的情况下,Cifar10上可以提升将近30%的性能。

图片




03


NonIID-Quagmire-Hsieh-CoRR2019


图片

图片

该文章《The Non-IID Data Quagmire of Decentralized Machine Learning》对联邦学习中的NonIID数据困境(Quagmire)进行了梳理介绍。

其主要内容包括:


1、调研了大量的实验,发现多种分布式算法(包括FedAvg)会在NonIID情形下失败;

2、发现基于Batch Normalization的方法更容易失败,基于Group Normalization的会好很多;

3、提出了一种调节通信频次的方法SkewScout,可以根据客户端上分布偏移的程度调整相应的通信频率。



文章对NonIID的定义是:比如,20% Non-IID指的是对原始数据的20%根据类别划分到各个客户端,其余80%的数据随机分配。至于根据类别划分的过程则依据于任务,比如Cifar10上有10个类别,5个客户端的话则每个客户端上有两个类的数据。

文章调研了三种算法,Gaia、FedAvg、DGC。其中Gaia是一种根据局部更新量是否显著进行决定是否要上传至服务器的算法,FedAvg主要是采用多轮本地迭代减少通信次数,DGC则是通过各种压缩方法(量化、梯度截断等)进行减少通信量。BSP(Bulk Synchronous Parallel)通过每次迭代之后都进行模型传输,当做性能的Baseline。

图片

上图展示的是Cifar10上的结果,使用AlexNet、GoogLeNet、LeNet、ResNet20时,NonIID场景下FedAvg的性能下降分别有16%,15%,67%,56%。

下图展示随着局部更新次数的增加,FedAvg性能的变化:

图片

可以看出,随着本地迭代次数的增加,IID下的准确率具有稍微上升的趋势,但是Non-IID上的性能急剧下降。这里文章的解释是通过FedAvg中全局模型聚合之后的更新量进行解释的,即Average Local Update Delta:

图片

下面是IID和NonIID场景下,模型在25个通信迭代中改变幅度的大小比较,可以看到NonIID的幅度会大于IID场景:

图片

文章还在ImageNet数据集、人脸识别数据集上验证了NonIID下模型性能急剧下降的事实,这里不罗列所有结果。

Batch Normalization可以使得优化目标更平坦,以及使用大的学习率,减少了之前使用大学习率会震荡的问题。Batch Normalization的基本做法是对一批样本 图片 求平均值和方差的统计量,记为 图片 ,然后对这批样本做标准化,即 图片 ,有的时候还会加入 图片 进行缩放和偏置,其中 图片 是可以学习的参数。而 图片 是通过数据统计出来的,如果Batch Size太小或者不同客户端数据是NonIID的,那么不同客户端上的这些统计量就会差异特别大。

下图统计了NonIID情形下两个客户端的 图片 ,计算了在在LeNet第一层不同通道上的均值Divergence:


图片

从上图可以看出,NonIID下(虚线)的均值Divergence会很大,并且是基本上每个Channel的差异都很大。

因此,文章探寻了几种可能的Normalization的方法,比如:Weight Normalization、Layer Normalization、Weight Renormalization、Group Normalization。关于这几种Normalization的细节和区别这里不再介绍,但是实验发现Group Normalization在NonIID下性能降低的最少,是最理想的Normalization方法。

图片

文章总结:在NonIID设置下,好的Normalization需要满足两个条件:1)不依赖于Batch计算,因为每个Batch的采样和客户端的数据分布极其相关;2)性能比较好。

关于上面的第一点存在疑问:Group Normalization不依赖于Batch计算,但是其会对每个样本的部分通道计算统计量,这部分统计量是和样本直接相关的。和单个样本直接相关或者和一个Batch的统计量直接相关不都会受分布差异影响的吗?因此对于Batch Normalization效果不是很好,但是Group Normalization的性能比较好的真实原因尚待探讨

后续文章测试了不同NonIID程度下的性能:

图片

可以看出随着NonIID程度的加深,FedAvg的性能一直在下降。

最后是SkewScout算法,文章从系统设计层面上设计了一个可以调控通信频率的算法:

图片


主要包括三个部分:第一,模型从客户端0发送到客户端1;第二,统计客户端1上性能的变化,当然不一定是Accuracy Loss,对比的是使用接收的模型的性能和原本自身模型性能的差异;第三,根据性能变化调控通信的一些参数。

总之,文章对联邦学习里面的NonIID实验进行了系统的介绍,个人感觉不足的地方是调研的联邦学习算法不够充分,以及个人感觉文章比较有意思的地方是各种Normalization在NonIID下的性能表现,以及其背后的原理(原理尚待探讨,文章只是给出了猜测和经验性结论)。



04


NonIID Theory


图片


下面的就比较硬核了,是NonIID下关于Weight Divergence的理论,先贴上理论的结果:


有 K 个客户端,每个客户端上有 图片 个样本采自于分布 图片 , 图片 。记 图片 是 图片 -Lipschitz的, 图片 代表的是每个类别。在FedAvg中,每个客户端更新 T 个Batch,然后聚合。第 m 次聚合的Weight Divergence有下面的不等式:
图片




下面一一介绍其推导过程。

定义损失函数为,分类的交叉熵损失:

图片其中 图片 为分类预测概率的第i项,上面式子只是将样本期望和类别求和项颠倒了一下次序,其中 图片 定义了模型输出的第i类概率的对数几率的期望,只和样本类别 i 和参数 图片 有关,因此将其记为 图片 ,其关于 图片 的梯度为 图片 ,即:


图片

因此根据梯度更新公式有:

图片

其中 图片 为Centralized训练时使用SGD的第c步权重,上标c代指Centralized。

同理,对于FedAvg中的第k个客户端,其优化的目标应该是:

图片

和SGD不同的是,主要是将分布 图片 换成了 图片 ,即Centralized训练优化的分布是根据全局分布 图片 进行优化的,而FedAvg中每个客户端的优化的分布是 图片 。那么第k个客户端的梯度更新为:

图片

然后是FedAvg的聚合过程,简单加权平均:

图片

其中 图片 是所有样本数目。

记 图片 为FedAvg第m次全局聚合后的权重, 图片 为第 m-1 次全局聚合后的权重,期间相差了 T 步局部更新的次数,即每个客户端会下载下来 图片 ,然后在上面更新 T 步,依次得到: 图片 ,最后所有的 图片 聚合得到 图片 。

记 图片 为使用SGD在Centralized的数据上第m * T步的结果, 图片 为其对应的更新历史轨迹。

接下来就是计算Weight Divergence,即 图片 。下面展示一下推导过程,这一结果记为Stage1:

图片

其中,第一个等号是直接利用FedAvg参数聚合替换得到的;第二个等号是将mT-1步到mT步的梯度更新替换得到的;第三个等号比较有意思,利用了 图片 的事实,以及 图片 的假设,将 图片 拆分为多个客户端上的分布的加权平均;第四个步骤(第一个不等号)利用了 图片 的不等式,以及 图片 的不等式;最后一步(第二个不等号)利用了Lipschitz的性质: 图片 。

上面只是将 图片 的关系推导到了 图片 ,还需要进一步推导至 图片 ,直到 图片 。

下一步,这一步骤记作Stage2:

图片

其中,第一个等号是将梯度更新步骤代入;第二个等号是重新组合了一下;第三个等号通过 图片 作为中介将 图片 和 图片 联系了起来,这是推导收敛性中常见的操作;最后一个不等号利用三角不等式和Lipschitz性质,以及 图片 的定义: 图片 ,和 图片 的定义: 图片 。

依次递推可以得到,这一步记作Stage3:

图片

将这一个式子代入Stage1的结果即可证明最开始给出的Theory。

总结一下,一般来说,联邦学习延承自分布式优化,对各种梯度下降、分布式SGD、异步SGD、带有动量的SGD等的收敛性证明要求比较高,因此大多数也会涉及很多理论证明。这里展示的理论只是初步的简单推导,就已经比较复杂了,因此对数学功底要求比较高。主要是符号太复杂,推着推着估计就乱掉了,作为入门,这个理论还是很适合的。





05


总结


图片


本文对两篇关于联邦学习中NonIID的文章进行了介绍,包括NonIID场景的构造、算法性能和需要解决的问题等等。最后展示了一个NonIID理论推导过程。


参考文献

  • Yue Zhao, Meng Li, Liangzhen Lai, Naveen Suda, Damon Civin, Vikas Chandra: Federated Learning with Non-IID Data. CoRR abs/1806.00582 (2018)

  • Kevin Hsieh, Amar Phanishayee, Onur Mutlu, Phillip B. Gibbons: The Non-IID Data Quagmire of Decentralized Machine Learning. CoRR abs/1910.00189 (2019)

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。