联邦学习算法之一ModelArts “pytorch_fedamp_emnist_classification”学习(三)
【摘要】 上周,我们更加深入地对数据进行了探索,将数据导入了虚拟租户,建立了一个两层卷积和两层全连接的神经网络类,在4.6.设置联邦学习虚拟租户时戛然而止。这周,让我们从设置虚拟租户开始。1. 以第一个虚拟租户为例分析虚拟租户的设置1.1. 设置环境变量os.environ['MOX_FEDERATED_RANK'] = str(i)跟前面的环境变量一样,至少到这里为止,我们并不是很清楚使用环境变量的...
上周,我们更加深入地对数据进行了探索,将数据导入了虚拟租户,建立了一个两层卷积和两层全连接的神经网络类,在4.6.设置联邦学习虚拟租户时戛然而止。这周,让我们从设置虚拟租户开始。
1. 以第一个虚拟租户为例分析虚拟租户的设置
1.1. 设置环境变量
os.environ['MOX_FEDERATED_RANK'] = str(i)
跟前面的环境变量一样,至少到这里为止,我们并不是很清楚使用环境变量的必要性
1.2. 设置联邦后端的参数
backend_config = fed_backend.FederatedBackendOBSConfig(
load_fn=torch_load,
save_fn=torch_save,
suffix='pth')
fed_backend来源于ModelArts moxing框架,from moxing.framework.federated import fed_backend。根据moxing.framework的介绍,"MoXing Framework模块为MoXing提供基础公共组件,例如访问华为云的OBS服务,和具体的AI引擎解耦,在ModelArts支持的所有AI引擎(TensorFlow、MXNet、PyTorch、MindSpore等)下均可以使用"。
可惜我们并没有找到更多关于moxing.framework.federated的说明,更不用说fed_backend。但从内容上,我们可以大概判断,FederatedBackendOBSConfig应该是一个设置联邦学习背后OBS桶参数的类,它采用torch的下载方法,torch的保存方法,文件格式为".pth"。
1.3. 基于第(二)期4.3.里设置的FedAMP参数创建一个联邦算法
fed_alg = fed_algorithm.build_algorithm(alg_config)
1.4. 基于上述联邦算法和1.2.的后端config创建一个后端backend
backend = fed_backend.build_backend(
config=backend_config,
fed_alg=fed_alg)
1.5. 创建租户,并连接
fed_client = client.FederatedHorizontalPTClient(backend=backend)
fed_client.connect()
1.6. 将创建的租户加入联邦租户列表
client_feds.append(fed_client)
1.7. 建立CNN实例
model = CNN()
1.8. 建立优化器
optimizer = torch.optim.Adam(model.parameters(), lr=stepsize)
1.9. 导入初始状态
model.load_state_dict(base_model.state_dict())
<All keys matched successfully>
1.10. 加模型和优化器加入模型列表
client_models.append((model, optimizer))
1.11. 将优化器加入优化器列表
optimizer_list.append(optimizer)
2. 用for循环设置所有虚租户
for i in range(num_clients):
os.environ['MOX_FEDERATED_RANK'] = str(i)
backend_config = fed_backend.FederatedBackendOBSConfig(
load_fn=torch_load,
save_fn=torch_save,
suffix='pth')
# create client
fed_alg = fed_algorithm.build_algorithm(alg_config)
backend = fed_backend.build_backend(
config=backend_config,
fed_alg=fed_alg)
fed_client = client.FederatedHorizontalPTClient(backend=backend)
fed_client.connect()
client_feds.append(fed_client)
model = CNN()
optimizer = torch.optim.Adam(model.parameters(), lr=stepsize)
model.load_state_dict(base_model.state_dict())
client_models.append((model, optimizer))
optimizer_list.append(optimizer)
3. 建立服务端模型
3.1. 按与虚拟租户相同的方式创建服务端的联邦算法
fed_alg = fed_algorithm.build_algorithm(alg_config)
3.2. 设置服务的后端参数
fed_be_config = fed_backend.FederatedBackendOBSConfig(
load_fn=torch_load, save_fn=torch_save, suffix='pth', hooks=[TorchFedHook()])
3.3. 创建服务的后端
fed_be = fed_backend.build_backend(fed_be_config, fed_alg)
3.4. 创建联邦服务
server = fed_server.FederatedServerHorizontal(fed_be)
3.5. 开启服务
server.backend.wait_ready()
INFO:root:waiting for all clients ready...
4. 创建评估模块
def evaluation(model, test_samples, test_labels):
#使用不进行梯度计算的方式来减少内存消耗
with torch.no_grad():
#获得model的输出
outputs = model(test_samples)
_, preds = outputs.max(1)
#计算结果与测试标签集相等的比例,即模型的正确率
test_acc = preds.eq(test_labels).sum() / float(len(test_labels))
return test_acc
5. 创建模型训练模块
def train_local_clients(model, optimizer, data, num_epochs):
#导入数据
train_samples = data.train_samples
train_labels = data.train_labels
val_samples = data.val_samples
val_labels = data.val_labels
#计算训练集长度
total_num_samples = len(train_samples)
#定义损失函数
criterion = torch.nn.CrossEntropyLoss()
#对训练次数循环
for x in range(num_epochs):
#生成numpy列表
seq = np.arange(total_num_samples)
#对列表进行随机排列
random.shuffle(seq)
#对训练集长度进行循环
for begin_batch_index in range(0, total_num_samples, batch_size):
#每一批的输入ids
batch_ids = seq[begin_batch_index: min(begin_batch_index + batch_size, total_num_samples)]
#每一批的输入
inputs = train_samples[batch_ids]
#每一批的标签
labels = train_labels[batch_ids]
#优化器初始化
optimizer.zero_grad()
#得到模型输出
outputs = model(inputs)
#定义损失函数
loss = criterion(outputs, labels.type(torch.long))
loss.backward()
optimizer.step()
#调用模型评估
test_acc = evaluation(model, val_samples, val_labels)
return test_acc
这期,我们创建了虚拟租户和服务端,以及服务端的训练函数和评估函数。下期,我们将以官方案例,正式探索个性化的联邦学习,并对学习结果进行可视化。再下期,我们将结合数据对联邦学习过程进行一些回顾和更深入的刨析。到那时,我们对“pytorch_fedamp_emnist_classification”的学习将暂时告一段落。
【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)