联邦学习算法之一ModelArts “pytorch_fedamp_emnist_classification”学习(二)
【摘要】 上周,我们已经在ModelAts上实现了pytorch_fedamp_emnist_classification的环境配置,对样本数据结构以及pytorch_fedamp_emnist_classification需要的数据结构进行了简单地探索。让我们开始这周的学习。。。1. 数据探索先让我们对数据进行一些更加详细地探索。1.1. 训练样本集dfl.load_file_binary(data...
上周,我们已经在ModelAts上实现了pytorch_fedamp_emnist_classification的环境配置,对样本数据结构以及pytorch_fedamp_emnist_classification需要的数据结构进行了简单地探索。让我们开始这周的学习。。。
1. 数据探索
先让我们对数据进行一些更加详细地探索。
1.1. 训练样本集
dfl.load_file_binary(data_filename=train_sample_filename + str(1) + filename_sx,
label_filename=train_label_filename + str(1) + filename_sx)
dfl.data.shape
(1000, 28, 28)
可以看到1个样本集中共有1000个样本,每个样本是一个28*28的二维列表,那么每个样本是怎么样的呢?
dfl.data[1] #可以多看几个样本
array([[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0.00392157, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0.01960784, 0.12941176,
0.26666667, 0.11764706, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0.17647059, 0.74509804,
0.87843137, 0.43921569, 0.01568627, 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0.01568627, 0.45098039, 0.96078431,
0.95294118, 0.44313725, 0.01568627, 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0.08627451, 0.6745098 , 0.97647059,
0.68627451, 0.13333333, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0.00784314, 0.32156863, 0.90980392, 0.91372549,
0.32941176, 0.01176471, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0.01568627, 0.49019608, 0.97647059, 0.85098039,
0.15294118, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0.01568627, 0.49803922, 0.98039216, 0.85098039,
0.14509804, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.00784314, 0.01568627, 0.01568627, 0.01568627, 0.00784314,
0. , 0. , 0. ],
[0. , 0.01568627, 0.49803922, 0.98039216, 0.85098039,
0.14509804, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.01568627, 0.07843137, 0.10588235,
0.30588235, 0.49019608, 0.49803922, 0.49019608, 0.30196078,
0.03137255, 0. , 0. ],
[0. , 0.01568627, 0.49803922, 0.98039216, 0.85098039,
0.14509804, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.01176471, 0.13333333, 0.45098039, 0.66666667, 0.72156863,
0.87058824, 0.97647059, 0.98039216, 0.97647059, 0.85490196,
0.30196078, 0.00784314, 0. ],
[0. , 0.01568627, 0.44705882, 0.96078431, 0.85098039,
0.15294118, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.03137255, 0.13333333,
0.32941176, 0.68627451, 0.95294118, 0.98823529, 0.99215686,
0.99607843, 0.99607843, 1. , 0.99607843, 0.97254902,
0.48627451, 0.01568627, 0. ],
[0. , 0. , 0.19607843, 0.86666667, 0.91764706,
0.37647059, 0.03921569, 0. , 0. , 0. ,
0.01568627, 0.03529412, 0.18039216, 0.49803922, 0.8 ,
0.91372549, 0.98431373, 0.99607843, 0.99607843, 0.99215686,
0.99215686, 0.99607843, 0.99607843, 0.99607843, 0.8627451 ,
0.30196078, 0.00784314, 0. ],
[0. , 0. , 0.1254902 , 0.79607843, 0.98823529,
0.86666667, 0.50196078, 0.2 , 0.14509804, 0.2 ,
0.44705882, 0.54901961, 0.81568627, 0.96470588, 0.99607843,
0.99607843, 0.99607843, 0.98431373, 0.87058824, 0.72156863,
0.74901961, 0.93333333, 0.99607843, 0.90980392, 0.48235294,
0.03921569, 0. , 0. ],
[0. , 0. , 0.03137255, 0.49803922, 0.96078431,
0.99607843, 0.96470588, 0.87058824, 0.85098039, 0.87058824,
0.96078431, 0.98039216, 0.99607843, 0.99607843, 0.99607843,
0.96470588, 0.8627451 , 0.66666667, 0.30980392, 0.11764706,
0.40784314, 0.91764706, 0.94901961, 0.51372549, 0.08627451,
0. , 0. , 0. ],
[0. , 0. , 0. , 0.18039216, 0.80784314,
0.98431373, 0.99607843, 0.99607843, 0.99607843, 0.99607843,
0.99607843, 0.99607843, 0.98039216, 0.97647059, 0.8627451 ,
0.50196078, 0.19607843, 0.08235294, 0.01176471, 0.07843137,
0.63921569, 0.80784314, 0.49411765, 0.1254902 , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0.02745098, 0.30588235,
0.66666667, 0.8627451 , 0.96078431, 0.98039216, 0.97647059,
0.91372549, 0.8 , 0.54901961, 0.49019608, 0.30196078,
0.03921569, 0. , 0. , 0. , 0.0745098 ,
0.49019608, 0.22352941, 0.03529412, 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0.01176471,
0.08235294, 0.19607843, 0.44705882, 0.49803922, 0.49019608,
0.32156863, 0.13333333, 0.03529412, 0.01568627, 0.00784314,
0. , 0. , 0. , 0. , 0.00392157,
0.05490196, 0.01960784, 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.01568627, 0.01568627, 0.01568627,
0.00784314, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ]])
可以看到每个样本是两端相对稀疏中间相对密集的二维列表。大家觉得这是个什么样的分类任务呢?我们可以一起来猜一下。
1.2. 训练标签集
dfl.label.shape
(1000,)
from matplotlib import pyplot as plt
plt.hist(dfl.label)
(array([560., 243., 22., 26., 22., 26., 22., 26., 22., 31.]),
array([ 0. , 6.1, 12.2, 18.3, 24.4, 30.5, 36.6, 42.7, 48.8, 54.9, 61. ]),
<a list of 10 Patch objects>)
可以看到在维度为1000的训练标签集中,一半以上的标签在0到6.1之间,剩下的一半左右分布在6.1到12.2,再剩下的则在12.2到61之间近似均匀分布
2. 将数据导入虚拟租户
上周我们其实已经设置了虚拟租户数量“num_clients = 62”。虚拟租户,即联邦学习的虚拟参与者。
建立租户的数据列表,将数据导入到虚拟租户的数据列表,每个租户的数据作为一个元素
#创建数据列表
client_data = []
#对62个虚拟租户进行循环
for i in range(num_clients):
#创建一个自定义的数据结构m_data
m_data = m_Data()
# 通过我们创建的文件数据加载类将第i个虚拟租户的训练数据(包括样本和标签)加载进来
dfl.load_file_binary(data_filename=train_sample_filename + str(i) + filename_sx,
label_filename=train_label_filename + str(i) + filename_sx)
samples, labels = dfl.getDataToTorch()
#我们用的是免费的试用装,是CPU版本,所以直接将数据以CPU版本格式导入到m_data
m_data.train_samples = samples.reshape(input_dim)
m_data.train_labels = labels.squeeze()
# 以同样的方式将第i个虚拟租户数据加载到m_data
dfl.load_file_binary(data_filename=val_sample_filename + str(i) + filename_sx,
label_filename=val_label_filename + str(i) + filename_sx)
samples, labels = dfl.getDataToTorch()
m_data.val_samples = samples.reshape(input_dim)
m_data.val_labels = labels.squeeze()
#把第i个租户的数据加入到创建的数据列表,作为其中的一个元素
client_data.append(m_data)
3. 建立两层卷积和两层全连接的网络
#继承torch.nn.Module类创建一个CNN类(torch.nn.Module是torch框架下构建神经网络的类)
class CNN(torch.nn.Module):
#自定义构造函数
def __init__(self, n_channel0=1, n_channel1=32, n_channel2=64,
kernel_size=5, stride=1, padding=2,
n_fc_input=3136, n_fc1=512, n_output=62):
#首先调用父类的结构函数
super(CNN, self).__init__()
#创建一个torch.nn的二维卷积,输入通道数in_channels=1,输出通道数out_channels=32,
##卷积核的尺寸kernel_size=5,步长stride=1,控制zero-padding的数目padding=2
##torch.nn.Conv2d有一个groups参数,默认为1,此时为全连接
self.conv1 = torch.nn.Conv2d(in_channels=n_channel0, out_channels=n_channel1,
kernel_size=kernel_size, stride=stride, padding=padding)
#创建第二个二维卷积,输入通道数=上个卷积的输出通道数,输出通道数=64,卷积核尺寸、步长和zero-padding数目与上个卷积保持一致
self.conv2 = torch.nn.Conv2d(in_channels=n_channel1, out_channels=n_channel2,
kernel_size=kernel_size, stride=stride, padding=padding)
#设置第一个全链接层,输入二维张量的维度为3136,输出二维张量的维度为512,对输入的最后一维进行线性变换
self.fc1 = torch.nn.Linear(n_fc_input, n_fc1)
self.fc2 = torch.nn.Linear(n_fc1, n_output)
#自定义forward函数,具体内容将在调用时结合例子进行理解
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(x.shape[0], -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
CNN()
CNN(
(conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(fc1): Linear(in_features=3136, out_features=512, bias=True)
(fc2): Linear(in_features=512, out_features=62, bias=True)
)
4. 设置联邦学习虚拟租户及FedAMP参数
4.1. 设置环境变量
os.environ['MOX_FEDERATED_SIZE'] = str(num_clients)
os.environ['MOX_FEDERATED_BACKEND'] = 'obs'
os.environ['MOX_FEDERATED_URL'] = '/tmp/fed_workspace/'
不太清楚是否必须采用这种环境变量的形式,是否可以用一般的全局变量,希望在下文能看到答案
4.2. 定义stepsize
stepsize = 0.001 #local training stepsize
具体用途同样需要下文再来观察
4.3. FedAMP参数设置
alg_config = fed_algorithm.FedAMPConfig(alpha=0.1, mul=50.0, display=True)
原文解说为:alpha--定义为FedAMP算法中本地模型的权重参数;mul--注意力参数,数值越大注意力增大,为零时将类似于经典联邦学习算法FedAvg;display--是否展示模型相似性及注意力权重。具体的功效请容我们后面再来探索
4.4. 初始化几个变量
client_models = []
client_feds = []
optimizer_list = []
变量的用途容后再议...
4.5. 创建一个两层卷积和两层全连接神经网络的实例
base_model = CNN()
4.6. 设置联邦学习的虚拟租户
尽管4.里留了一堆的问题,但4.6.需要很大的篇幅,并且很多问题也无法在4.6.里获得解答。受限于时间,本期学习就到这里。我们更加深入地对数据进行了探索,将数据导入了虚拟租户,建立了一个两层卷积和两层全连接的神经网络类,并且留下很多悬念(笑脸)。让我们下期再见。。。
【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)