联邦学习算法之一ModelArts “pytorch_fedamp_emnist_classification”学习(二)

举报
darkpard 发表于 2021/03/18 19:28:27 2021/03/18
【摘要】 更加深入地对数据进行了探索,将数据导入了虚拟租户,建立了一个两层卷积和两层全连接的神经网络类,并且留下很多悬念(笑脸)

上周,我们已经在ModelAts上实现了pytorch_fedamp_emnist_classification的环境配置,对样本数据结构以及pytorch_fedamp_emnist_classification需要的数据结构进行了简单地探索。让我们开始这周的学习。。。

1. 数据探索

先让我们对数据进行一些更加详细地探索。

1.1. 训练样本集

dfl.load_file_binary(data_filename=train_sample_filename + str(1) + filename_sx,
                         label_filename=train_label_filename + str(1) + filename_sx)
dfl.data.shape
(1000, 28, 28)

可以看到1个样本集中共有1000个样本,每个样本是一个28*28的二维列表,那么每个样本是怎么样的呢?

dfl.data[1]  #可以多看几个样本
array([[0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.00392157, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.01960784, 0.12941176,
        0.26666667, 0.11764706, 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.17647059, 0.74509804,
        0.87843137, 0.43921569, 0.01568627, 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.01568627, 0.45098039, 0.96078431,
        0.95294118, 0.44313725, 0.01568627, 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.08627451, 0.6745098 , 0.97647059,
        0.68627451, 0.13333333, 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.00784314, 0.32156863, 0.90980392, 0.91372549,
        0.32941176, 0.01176471, 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.01568627, 0.49019608, 0.97647059, 0.85098039,
        0.15294118, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.01568627, 0.49803922, 0.98039216, 0.85098039,
        0.14509804, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.00784314, 0.01568627, 0.01568627, 0.01568627, 0.00784314,
        0.        , 0.        , 0.        ],
       [0.        , 0.01568627, 0.49803922, 0.98039216, 0.85098039,
        0.14509804, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.01568627, 0.07843137, 0.10588235,
        0.30588235, 0.49019608, 0.49803922, 0.49019608, 0.30196078,
        0.03137255, 0.        , 0.        ],
       [0.        , 0.01568627, 0.49803922, 0.98039216, 0.85098039,
        0.14509804, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.01176471, 0.13333333, 0.45098039, 0.66666667, 0.72156863,
        0.87058824, 0.97647059, 0.98039216, 0.97647059, 0.85490196,
        0.30196078, 0.00784314, 0.        ],
       [0.        , 0.01568627, 0.44705882, 0.96078431, 0.85098039,
        0.15294118, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.03137255, 0.13333333,
        0.32941176, 0.68627451, 0.95294118, 0.98823529, 0.99215686,
        0.99607843, 0.99607843, 1.        , 0.99607843, 0.97254902,
        0.48627451, 0.01568627, 0.        ],
       [0.        , 0.        , 0.19607843, 0.86666667, 0.91764706,
        0.37647059, 0.03921569, 0.        , 0.        , 0.        ,
        0.01568627, 0.03529412, 0.18039216, 0.49803922, 0.8       ,
        0.91372549, 0.98431373, 0.99607843, 0.99607843, 0.99215686,
        0.99215686, 0.99607843, 0.99607843, 0.99607843, 0.8627451 ,
        0.30196078, 0.00784314, 0.        ],
       [0.        , 0.        , 0.1254902 , 0.79607843, 0.98823529,
        0.86666667, 0.50196078, 0.2       , 0.14509804, 0.2       ,
        0.44705882, 0.54901961, 0.81568627, 0.96470588, 0.99607843,
        0.99607843, 0.99607843, 0.98431373, 0.87058824, 0.72156863,
        0.74901961, 0.93333333, 0.99607843, 0.90980392, 0.48235294,
        0.03921569, 0.        , 0.        ],
       [0.        , 0.        , 0.03137255, 0.49803922, 0.96078431,
        0.99607843, 0.96470588, 0.87058824, 0.85098039, 0.87058824,
        0.96078431, 0.98039216, 0.99607843, 0.99607843, 0.99607843,
        0.96470588, 0.8627451 , 0.66666667, 0.30980392, 0.11764706,
        0.40784314, 0.91764706, 0.94901961, 0.51372549, 0.08627451,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.18039216, 0.80784314,
        0.98431373, 0.99607843, 0.99607843, 0.99607843, 0.99607843,
        0.99607843, 0.99607843, 0.98039216, 0.97647059, 0.8627451 ,
        0.50196078, 0.19607843, 0.08235294, 0.01176471, 0.07843137,
        0.63921569, 0.80784314, 0.49411765, 0.1254902 , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.02745098, 0.30588235,
        0.66666667, 0.8627451 , 0.96078431, 0.98039216, 0.97647059,
        0.91372549, 0.8       , 0.54901961, 0.49019608, 0.30196078,
        0.03921569, 0.        , 0.        , 0.        , 0.0745098 ,
        0.49019608, 0.22352941, 0.03529412, 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.01176471,
        0.08235294, 0.19607843, 0.44705882, 0.49803922, 0.49019608,
        0.32156863, 0.13333333, 0.03529412, 0.01568627, 0.00784314,
        0.        , 0.        , 0.        , 0.        , 0.00392157,
        0.05490196, 0.01960784, 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.01568627, 0.01568627, 0.01568627,
        0.00784314, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ]])

可以看到每个样本是两端相对稀疏中间相对密集的二维列表。大家觉得这是个什么样的分类任务呢?我们可以一起来猜一下。

1.2. 训练标签集

dfl.label.shape
(1000,)
from matplotlib import pyplot as plt 
plt.hist(dfl.label)
(array([560., 243.,  22.,  26.,  22.,  26.,  22.,  26.,  22.,  31.]),
array([ 0. ,  6.1, 12.2, 18.3, 24.4, 30.5, 36.6, 42.7, 48.8, 54.9, 61. ]),
<a list of 10 Patch objects>)

可以看到在维度为1000的训练标签集中,一半以上的标签在06.1之间,剩下的一半左右分布在6.112.2,再剩下的则在12.261之间近似均匀分布

2. 将数据导入虚拟租户

上周我们其实已经设置了虚拟租户数量“num_clients = 62”。虚拟租户,即联邦学习的虚拟参与者。

建立租户的数据列表,将数据导入到虚拟租户的数据列表,每个租户的数据作为一个元素

#创建数据列表
client_data = []
#62个虚拟租户进行循环
for i in range(num_clients):
    #创建一个自定义的数据结构m_data
    m_data = m_Data()
    # 通过我们创建的文件数据加载类将第i个虚拟租户的训练数据(包括样本和标签)加载进来
    dfl.load_file_binary(data_filename=train_sample_filename + str(i) + filename_sx,
                         label_filename=train_label_filename + str(i) + filename_sx)
    samples, labels = dfl.getDataToTorch()
    #我们用的是免费的试用装,是CPU版本,所以直接将数据以CPU版本格式导入到m_data
    m_data.train_samples = samples.reshape(input_dim)
    m_data.train_labels = labels.squeeze()

    # 以同样的方式将第i个虚拟租户数据加载到m_data
    dfl.load_file_binary(data_filename=val_sample_filename + str(i) + filename_sx,
                         label_filename=val_label_filename + str(i) + filename_sx)
    samples, labels = dfl.getDataToTorch()
    m_data.val_samples = samples.reshape(input_dim)
    m_data.val_labels = labels.squeeze()
    #把第i个租户的数据加入到创建的数据列表,作为其中的一个元素
    client_data.append(m_data)

3. 建立两层卷积和两层全连接的网络

#继承torch.nn.Module类创建一个CNN(torch.nn.Moduletorch框架下构建神经网络的类)
class CNN(torch.nn.Module):
    #自定义构造函数
    def __init__(self, n_channel0=1, n_channel1=32, n_channel2=64,
                 kernel_size=5, stride=1, padding=2,
                 n_fc_input=3136, n_fc1=512, n_output=62):
        #首先调用父类的结构函数
        super(CNN, self).__init__()
        #创建一个torch.nn的二维卷积,输入通道数in_channels=1,输出通道数out_channels=32
            ##卷积核的尺寸kernel_size=5,步长stride=1,控制zero-padding的数目padding=2
            ##torch.nn.Conv2d有一个groups参数,默认为1,此时为全连接
        self.conv1 = torch.nn.Conv2d(in_channels=n_channel0, out_channels=n_channel1,
                               kernel_size=kernel_size, stride=stride, padding=padding)
        #创建第二个二维卷积,输入通道数=上个卷积的输出通道数,输出通道数=64,卷积核尺寸、步长和zero-padding数目与上个卷积保持一致
        self.conv2 = torch.nn.Conv2d(in_channels=n_channel1, out_channels=n_channel2,
                               kernel_size=kernel_size, stride=stride, padding=padding)
        #设置第一个全链接层,输入二维张量的维度为3136,输出二维张量的维度为512,对输入的最后一维进行线性变换
        self.fc1 = torch.nn.Linear(n_fc_input, n_fc1)
        self.fc2 = torch.nn.Linear(n_fc1, n_output)
 
    #自定义forward函数,具体内容将在调用时结合例子进行理解
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
   
CNN()
CNN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (fc1): Linear(in_features=3136, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=62, bias=True)
)

4. 设置联邦学习虚拟租户及FedAMP参数

4.1. 设置环境变量

os.environ['MOX_FEDERATED_SIZE'] = str(num_clients)
os.environ['MOX_FEDERATED_BACKEND'] = 'obs'
os.environ['MOX_FEDERATED_URL'] = '/tmp/fed_workspace/'

不太清楚是否必须采用这种环境变量的形式,是否可以用一般的全局变量,希望在下文能看到答案

4.2. 定义stepsize

stepsize = 0.001 #local training stepsize

具体用途同样需要下文再来观察

4.3. FedAMP参数设置

alg_config = fed_algorithm.FedAMPConfig(alpha=0.1, mul=50.0, display=True)

原文解说为:alpha--定义为FedAMP算法中本地模型的权重参数;mul--注意力参数,数值越大注意力增大,为零时将类似于经典联邦学习算法FedAvgdisplay--是否展示模型相似性及注意力权重。具体的功效请容我们后面再来探索

4.4. 初始化几个变量

client_models = []
client_feds = []
optimizer_list = []

变量的用途容后再议...

4.5. 创建一个两层卷积和两层全连接神经网络的实例

base_model = CNN()

4.6. 设置联邦学习的虚拟租户

尽管4.里留了一堆的问题,但4.6.需要很大的篇幅,并且很多问题也无法在4.6.里获得解答。受限于时间,本期学习就到这里。我们更加深入地对数据进行了探索,将数据导入了虚拟租户,建立了一个两层卷积和两层全连接的神经网络类,并且留下很多悬念(笑脸)。让我们下期再见。。。

 

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。