Federated Learning Algorithms on ModelArts: Learning the "pytorch_fedamp_emnist_classification" Case (Part 4)
In this installment, we will actually run through the pytorch_fedamp_emnist_classification case.
1. Parameter initialization
1.1. Set the batch size
batch_size = 100
1.2. Set the number of local training epochs
num_epochs = 5
1.3. Set the number of synchronization (communication) rounds between the clients and the server
communication = 20
1.4. Initialize the accuracy lists
val_acc_cloud = []
val_acc_client = []
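Before walking through the real code, here is a minimal pure-Python sketch of the schedule these hyperparameters drive: each communication round runs `num_epochs` of local training on every client, then one server aggregation. The `train_locally` and `aggregate_on_server` functions below are hypothetical stubs, not the ModelArts API.

```python
# Toy sketch of the training schedule; all functions are stand-in stubs.
num_clients = 62      # as in this EMNIST case (clients 0..61)
batch_size = 100
num_epochs = 5        # local epochs per communication round
communication = 20    # client-server synchronization rounds

def train_locally(client_id, epochs):
    """Stub: would run `epochs` of training on the client's own data."""
    return 0.5  # pretend validation accuracy

def aggregate_on_server():
    """Stub: would combine the uploaded client models on the server."""
    pass

val_acc_client = []
for rounds in range(communication):
    total = sum(train_locally(i, num_epochs) for i in range(num_clients))
    aggregate_on_server()
    val_acc_client.append(total / num_clients * 100)

print(len(val_acc_client))  # one averaged accuracy per communication round
```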
2. Walk through the training process, using the first synchronization round as an example
rounds = 0
2.1. Initialize the accumulated accuracy for the first round
val_acc_total = 0
2.2. Train on each client's training set
2.2.1. Explore the training process, using the first client as an example
i = 0
2.2.1.1. Get the first client's model and optimizer from the model list
model, optimizer = client_models[i]
2.2.1.2. Get the first client from the client list
local_client = client_feds[i]
2.2.1.3. Get the first client's data from the client data list
data = client_data[i]
2.2.1.4. Train on the first client's training set and obtain its validation accuracy
acc = train_local_clients(model, optimizer, data, num_epochs)
2.2.1.5. Add it to the accumulated accuracy over all clients
val_acc_total += acc
2.2.1.6. Send the optimizer's parameters to the server
local_client.send_model(parameters=optimizer.param_groups[0]['params'])
2.2.2. Train on the remaining clients' training sets
for i in range(1, num_clients):
    print("Training client " + str(i), end='\r')
    model, optimizer = client_models[i]
    local_client = client_feds[i]
    data = client_data[i]
    acc = train_local_clients(model, optimizer, data, num_epochs)
    val_acc_total += acc
    local_client.send_model(parameters=optimizer.param_groups[0]['params'])
Training client 61
2.3. Compute the average per-client accuracy and append it to the client accuracy list
val_acc_avg = val_acc_total / num_clients * 100
val_acc_client.append(val_acc_avg)
2.4. Get the current step and the server backend context, then aggregate the uploaded models on the server
step = list(server.backend.wait_next())[-1]
ctx = server.backend.get_context()
server.backend.aggregate_data(step, ctx.names)
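The `aggregate_data` call is where FedAMP builds, for every client, a personalized "cloud" model as a similarity-weighted combination of all clients' uploaded parameters, so that clients with similar data help each other most. The ModelArts implementation is not exposed here; the following is only a toy illustration of that attention-weighted idea on scalar "models", with a made-up distance-based weighting (the real FedAMP formula differs).

```python
import math

# Toy illustration of attention-style weighted aggregation in the
# spirit of FedAMP: clients with similar parameters receive larger
# weights in each other's personalized cloud model.
client_params = [0.9, 1.0, 1.1, 5.0]  # scalar "models"; the last is an outlier
mul = 50.0  # sharpness, mirroring FedAMP(mul=50.0) in the log above

def personalized_model(i, params):
    # Similarity = exp(-mul * squared distance), then normalize to weights.
    weights = [math.exp(-mul * (params[i] - p) ** 2) for p in params]
    total = sum(weights)
    return sum(w / total * p for w, p in zip(weights, params))

cloud = [personalized_model(i, client_params) for i in range(len(client_params))]
# The three similar clients are pulled toward each other, while the
# outlier's personalized model stays close to its own parameters.
print([round(v, 3) for v in cloud])
```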
INFO:root:waiting for next federation...
INFO:root:ready to do federation. <step={'0'}> <clients=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]> <algorithms=FedAMP(alpha=0.1, mul=50.0, sim=None)>
2.5. Update the client-side models with the aggregated parameters and re-evaluate the accuracy
val_acc_total = 0
with torch.no_grad():
    for i in range(num_clients):
        val_samples = client_data[i].val_samples
        val_labels = client_data[i].val_labels
        local_client = client_feds[i]
        model, optimizer = client_models[i]
        new_parameters = local_client.get_model()
        local_client.clean_previous()
        local_client.next_step()
        parameters = optimizer.param_groups[0]['params']
        for p, new_p in zip(parameters, new_parameters):
            p.copy_(new_p)
        acc = evaluation(model, val_samples, val_labels)
        val_acc_total += acc
val_acc_avg = val_acc_total / num_clients * 100
val_acc_cloud.append(val_acc_avg)
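One detail worth noting in the update step: the code copies the downloaded values into the existing parameter tensors with `p.copy_(new_p)` instead of rebinding the variable, so the model and optimizer keep referencing the very same tensor objects. A torch-free toy, using a hypothetical `Param` class in place of a tensor, shows why this matters:

```python
# Why update in place? The model and optimizer hold references to the
# same parameter objects; rebinding a local name would not change what
# they see. `Param` is a stand-in for a torch tensor here.
class Param:
    def __init__(self, value):
        self.value = value
    def copy_(self, other):       # mimics in-place torch.Tensor.copy_
        self.value = other.value

model_params = [Param(0.1), Param(0.2)]    # held by the "model"
optimizer_view = list(model_params)        # param_groups[0]['params'] analogue

new_parameters = [Param(1.0), Param(2.0)]  # "downloaded" from the server
for p, new_p in zip(optimizer_view, new_parameters):
    p.copy_(new_p)                         # in place: same objects, new values

print([p.value for p in model_params])     # the model sees the update
assert model_params[0] is optimizer_view[0]  # object identity preserved
```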
3. Run the remaining 19 communication rounds
for rounds in range(1, communication):
    val_acc_total = 0

    # Train models locally
    print("===Starting local training===")
    for i in range(num_clients):
        print("Training client " + str(i), end='\r')
        model, optimizer = client_models[i]
        local_client = client_feds[i]
        data = client_data[i]
        acc = train_local_clients(model, optimizer, data, num_epochs)
        val_acc_total += acc
        local_client.send_model(parameters=optimizer.param_groups[0]['params'])

    # Aggregate on the server
    val_acc_avg = val_acc_total / num_clients * 100
    val_acc_client.append(val_acc_avg)
    clear_output(wait=True)
    print("\n===Local training is completed===")
    print("Communication round: {0:1d}".format(rounds + 1))
    print("===Starting model aggregation===")
    step = list(server.backend.wait_next())[-1]
    ctx = server.backend.get_context()
    server.backend.aggregate_data(step, ctx.names)
    print("===Model aggregation is completed===")
    print("Local model average validation test accuracy: {0:5.2f}%".format(val_acc_avg))

    # Update local models
    val_acc_total = 0
    with torch.no_grad():
        for i in range(num_clients):
            val_samples = client_data[i].val_samples
            val_labels = client_data[i].val_labels
            local_client = client_feds[i]
            model, optimizer = client_models[i]
            new_parameters = local_client.get_model()
            local_client.clean_previous()
            local_client.next_step()
            parameters = optimizer.param_groups[0]['params']
            for p, new_p in zip(parameters, new_parameters):
                p.copy_(new_p)
            acc = evaluation(model, val_samples, val_labels)
            val_acc_total += acc
    val_acc_avg = val_acc_total / num_clients * 100
    val_acc_cloud.append(val_acc_avg)
    print("Personalized cloud model average validation accuracy: {0:5.2f}%".format(val_acc_avg))
print("====Training Completed====")
INFO:root:waiting for next federation...
INFO:root:ready to do federation. <step={'19'}> <clients=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]> <algorithms=FedAMP(alpha=0.1, mul=50.0, sim=None)>
===Local training is completed===
Communication round: 20
===Starting model aggregation===
===Model aggregation is completed===
Local model average validation test accuracy: 75.66%
Personalized cloud model average validation accuracy: 81.30%
====Training Completed====
4. Plot the accuracy after each synchronization round
plt.figure(figsize=(8, 6))
x = list(range(communication))
plt.xticks(x[0:-1:2])
acc_client, = plt.plot(x, val_acc_client, color="red")
acc_cloud, = plt.plot(x, val_acc_cloud, color="blue")
plt.legend([acc_client, acc_cloud], ['Client Acc', 'Cloud Acc'], fontsize=20)
plt.xlabel("Communication round", fontsize=18)
plt.ylabel("Mean validation accuracy", fontsize=18)
plt.title("Mean validation accuracy vs. Communication round", fontsize=18)
plt.show()
As the chart shows, in the first round the cloud-side and client-side accuracies are similar, both only around 50%; after 20 rounds the client-side accuracy exceeds 75%, while the cloud-side accuracy exceeds 80%.
That is all for this installment. Next time we will dig deeper into the modeling process with concrete data.