Federated Learning Algorithms: Studying ModelArts "pytorch_fedamp_emnist_classification" (Part 4)

Posted by 静怡酱酱酱酱酱酱酱 on 2021/07/16 16:44:50

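In this installment, we formally implement the pytorch_fedamp_emnist_classification case, step by step.

1. Parameter initialization

1.1. Set the size of each batch

batch_size = 100

1.2. Set the number of training epochs

num_epochs = 5

1.3. Set the number of synchronization rounds between the clients and the server

communication = 20

1.4. Initialize the accuracy lists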

val_acc_cloud = []
val_acc_client = []
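
As a quick sanity check on these settings, the sketch below works out the training budget they imply. It assumes, as the later training loop suggests, that every client runs `num_epochs` local epochs in each of the `communication` rounds. The name `total_local_epochs` is introduced here purely for illustration.

```python
batch_size = 100      # samples per mini-batch
num_epochs = 5        # local epochs per client in each round
communication = 20    # client/server synchronization rounds

# Assumption: each client trains for num_epochs epochs in every round.
total_local_epochs = num_epochs * communication

# One accuracy value is appended per round, so both lists should
# hold `communication` entries once training has finished.
val_acc_cloud = []
val_acc_client = []

print(total_local_epochs)  # 100
```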

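2. Walk through the training process, using the first synchronization round as an example

rounds = 0

2.1. Initialize the accuracy total for the first synchronization round

val_acc_total = 0

2.2. Train on the clients' training sets

2.2.1. Explore the training process, using the first client as an example

i = 0

2.2.1.1. Get the first client's model and optimizer from the model list

model, optimizer = client_models[i]

2.2.1.2. Get the first client from the client list

local_client = client_feds[i]

2.2.1.3. Get the first client's data from the client data list

data = client_data[i]

2.2.1.4. Train on the first client's training set and obtain the accuracy

acc = train_local_clients(model, optimizer, data, num_epochs)

2.2.1.5. Add this accuracy to the running total across all clients

val_acc_total += acc

2.2.1.6. Send the model parameters held by the optimizer

local_client.send_model(parameters=optimizer.param_groups[0]['params'])

2.2.2. Train on the training sets of all remaining clients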

for i in range(1, num_clients):
    print("Training client " + str(i), end='\r')
    model, optimizer = client_models[i]
    local_client = client_feds[i]
    data = client_data[i]
    acc = train_local_clients(model, optimizer, data, num_epochs)  
    val_acc_total += acc
    local_client.send_model(parameters=optimizer.param_groups[0]['params']) 

Training client 61

2.3. Compute the average accuracy per client and append it to the client accuracy list

val_acc_avg = val_acc_total / num_clients *100
val_acc_client.append(val_acc_avg)
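
To make the averaging step concrete, here is a minimal sketch on toy numbers. It assumes that `train_local_clients` returns accuracy as a fraction between 0 and 1, which the multiplication by 100 suggests but the code shown here does not confirm. The helper `mean_accuracy_percent` is a hypothetical name and is not part of the notebook. Note that this is an unweighted mean, so every client counts equally regardless of how much data it holds.

```python
def mean_accuracy_percent(accuracies):
    """Unweighted mean of per-client accuracy fractions, as a percentage."""
    if not accuracies:
        raise ValueError("need at least one client accuracy")
    return sum(accuracies) / len(accuracies) * 100


# Toy values standing in for the per-client results of local training.
toy_accuracies = [0.5, 0.75, 1.0]
val_acc_avg = mean_accuracy_percent(toy_accuracies)
print(val_acc_avg)
```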

2.4. Get the current step and the server backend context, then run the aggregation on the server

step = list(server.backend.wait_next())[-1]
ctx = server.backend.get_context()
server.backend.aggregate_data(step, ctx.names)

INFO:root:waiting for next federation...

INFO:root:ready to do federation. <step={'0'}> <clients=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]> <algorithms=FedAMP(alpha=0.1, mul=50.0, sim=None)>
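
The log names FedAMP as the aggregation algorithm. The sketch below is a toy illustration of the general idea of similarity-weighted, per-client aggregation: each client receives its own mix of all the models, with models closer to its own counting more. It is not the ModelArts implementation, and it is a simplified stand-in rather than the exact FedAMP update. The function `personalized_average` and its parameters `alpha` and `sigma` are assumptions made for this example, and no claim is made about how the logged `alpha=0.1` and `mul=50.0` enter the real computation.

```python
import math


def personalized_average(client_weights, alpha=0.1, sigma=1.0):
    """Toy similarity-weighted aggregation (not the ModelArts code).

    Each client keeps weight 1 - alpha on its own parameters and spreads
    alpha over the other clients in proportion to exp(-dist^2 / sigma).
    """
    n = len(client_weights)
    personalized = []
    for i in range(n):
        own = client_weights[i]
        # Similarity of every other client's parameters to this client's.
        sims = []
        for j in range(n):
            if j == i:
                sims.append(0.0)
                continue
            dist_sq = sum((a - b) ** 2 for a, b in zip(own, client_weights[j]))
            sims.append(math.exp(-dist_sq / sigma))
        total = sum(sims)
        if total == 0.0:
            # No usable neighbours: keep the client's own parameters.
            personalized.append(list(own))
            continue
        coeffs = [alpha * s / total for s in sims]
        coeffs[i] = 1.0 - alpha
        mixed = [
            sum(coeffs[j] * client_weights[j][k] for j in range(n))
            for k in range(len(own))
        ]
        personalized.append(mixed)
    return personalized


# Clients 0 and 1 are similar; client 2 is far away from both.
weights = [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]]
result = personalized_average(weights)
print(result[0])
```

In this toy run, client 0's personalized parameters move slightly toward client 1 and are almost unaffected by the distant client 2, which is the behaviour a personalized aggregation scheme aims for when client data distributions differ.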


2.5. Update the client-side models and recompute the accuracy

val_acc_total = 0
with torch.no_grad():  # in-place parameter updates must not be tracked by autograd
    for i in range(num_clients):
        val_samples = client_data[i].val_samples
        val_labels = client_data[i].val_labels
        local_client = client_feds[i]
        model, optimizer = client_models[i]
        # Fetch the aggregated parameters for this client
        new_parameters = local_client.get_model()
        local_client.clean_previous()
        local_client.next_step()
        # Overwrite the local parameters in place with the new ones
        parameters = optimizer.param_groups[0]['params']
        for p, new_p in zip(parameters, new_parameters):
            p.copy_(new_p)
        # Evaluate the updated model on this client's validation set
        acc = evaluation(model, val_samples, val_labels)
        val_acc_total += acc
val_acc_avg = val_acc_total / num_clients * 100
val_acc_cloud.append(val_acc_avg)
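
It may not be obvious why writing into `optimizer.param_groups[0]['params']` changes the model. The sketch below shows the mechanism on a tiny `torch.nn.Linear` layer: when the optimizer is built from `model.parameters()`, its first parameter group holds the very same tensor objects, so an in-place `copy_` updates the model directly. The list `new_parameters` is a stand-in for the result of `local_client.get_model()`, which is assumed here to yield tensors in the same order and with the same shapes. The `torch.no_grad()` block is needed because PyTorch refuses in-place writes to leaf tensors that require gradients.

```python
import torch

# A tiny model and optimizer, standing in for one entry of client_models.
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

parameters = optimizer.param_groups[0]['params']

# The optimizer holds the model's own parameter tensors, not copies.
same_objects = all(p is q for p, q in zip(parameters, model.parameters()))

# Stand-in for local_client.get_model(): all-ones tensors of matching shape.
new_parameters = [torch.ones_like(p) for p in parameters]

with torch.no_grad():
    for p, new_p in zip(parameters, new_parameters):
        p.copy_(new_p)  # in-place overwrite, so the model sees the new values

print(same_objects)
print(model.weight, model.bias)
```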

3. Complete the remaining 19 rounds of training

for rounds in range(1, communication):
    val_acc_total = 0

    # Train each client's model locally and send its parameters
    print("===Starting local training===")
    for i in range(num_clients):
        print("Training client " + str(i), end='\r')
        model, optimizer = client_models[i]
        local_client = client_feds[i]
        data = client_data[i]
        acc = train_local_clients(model, optimizer, data, num_epochs)
        val_acc_total += acc
        local_client.send_model(parameters=optimizer.param_groups[0]['params'])

    # Record the mean accuracy of the locally trained models
    val_acc_avg = val_acc_total / num_clients * 100
    val_acc_client.append(val_acc_avg)

    clear_output(wait=True)
    print("\n===Local training is completed===")
    print("Communication round: {0:1d}".format(rounds + 1))
    print("===Starting model aggregation===")

    # Aggregate on the server
    step = list(server.backend.wait_next())[-1]
    ctx = server.backend.get_context()
    server.backend.aggregate_data(step, ctx.names)

    print("===Model aggregation is completed===")
    print("Local model average validation test accuracy: {0:5.2f}%".format(val_acc_avg))

    # Update each local model with the aggregated parameters and re-evaluate
    val_acc_total = 0
    with torch.no_grad():
        for i in range(num_clients):
            val_samples = client_data[i].val_samples
            val_labels = client_data[i].val_labels
            local_client = client_feds[i]
            model, optimizer = client_models[i]
            new_parameters = local_client.get_model()
            local_client.clean_previous()
            local_client.next_step()
            parameters = optimizer.param_groups[0]['params']
            for p, new_p in zip(parameters, new_parameters):
                p.copy_(new_p)
            acc = evaluation(model, val_samples, val_labels)
            val_acc_total += acc
    val_acc_avg = val_acc_total / num_clients * 100
    val_acc_cloud.append(val_acc_avg)
    print("Personalized cloud model average validation accuracy: {0:5.2f}%".format(val_acc_avg))

print("====Training Completed====")

INFO:root:waiting for next federation...

INFO:root:ready to do federation. <step={'19'}> <clients=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]> <algorithms=FedAMP(alpha=0.1, mul=50.0, sim=None)>

===Local training is completed===

Communication round: 20

===Starting model aggregation===


===Model aggregation is completed===

Local model average validation test accuracy: 75.66%

Personalized cloud model average validation accuracy: 81.30%

====Training Completed====
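
The first round and the remaining 19 rounds repeat the same sequence: train locally, record the client accuracy, aggregate, pull the new parameters, and record the cloud accuracy. The sketch below shows that structure as a single loop over all rounds. It does not use the ModelArts API. The function `run_federated_rounds` and its callables `train_fn`, `aggregate_fn`, and `evaluate_fn` are hypothetical placeholders for the notebook's training, aggregation, and update-and-evaluate steps, and the callables are assumed to return accuracy as a fraction.

```python
def run_federated_rounds(num_rounds, num_clients, train_fn, aggregate_fn, evaluate_fn):
    """Run every communication round in one loop and collect both curves."""
    val_acc_client, val_acc_cloud = [], []
    for rounds in range(num_rounds):
        # Local training: accumulate each client's accuracy fraction.
        local_total = 0.0
        for i in range(num_clients):
            local_total += train_fn(i)
        val_acc_client.append(local_total / num_clients * 100)

        # Server-side aggregation for this round.
        aggregate_fn(rounds)

        # Evaluation after each client has pulled its new parameters.
        cloud_total = 0.0
        for i in range(num_clients):
            cloud_total += evaluate_fn(i)
        val_acc_cloud.append(cloud_total / num_clients * 100)
    return val_acc_client, val_acc_cloud


# Toy stand-ins: constant accuracies and an aggregation step that only logs.
aggregated_rounds = []
client_curve, cloud_curve = run_federated_rounds(
    num_rounds=3,
    num_clients=4,
    train_fn=lambda i: 0.5,
    aggregate_fn=aggregated_rounds.append,
    evaluate_fn=lambda i: 0.75,
)
print(client_curve, cloud_curve)
```

Written this way, the first round needs no special handling, and both accuracy lists receive exactly one entry per round.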

4. Plot the accuracy after each synchronization round

plt.figure(figsize=(8, 6))
x = list(range(communication))
plt.xticks(x[0:-1:2])
acc_client, = plt.plot(x, val_acc_client, color="red")
acc_cloud, = plt.plot(x, val_acc_cloud, color="blue")
plt.legend([acc_client, acc_cloud], ['Client Acc', 'Cloud Acc'], fontsize=20)
plt.xlabel("Communication round", fontsize=18)
plt.ylabel("Mean validation accuracy", fontsize=18)
plt.title("Mean validation accuracy vs. Communication round", fontsize=18)
plt.show()
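
One detail of the plotting code is easy to misread: the x values start at 0, while the training log counts rounds from 1. The sketch below prints what the tick slice `x[0:-1:2]` selects and builds a matching set of 1-based labels. The name `round_labels` is introduced here only for illustration.

```python
communication = 20

x = list(range(communication))          # 0..19, the x values used in the plot
ticks = x[0:-1:2]                       # every second position, last index excluded
round_labels = [r + 1 for r in ticks]   # hypothetical 1-based round numbers

print(ticks)
print(round_labels)
```

If round numbers matching the log are preferred on the axis, one option is to pass both lists to `plt.xticks(ticks, round_labels)`.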


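From the plot we can see that in the first round the cloud accuracy and the client accuracy are about the same, both only around 50%. After 20 rounds, the client accuracy exceeds 75%, while the cloud accuracy exceeds 80%.

That concludes this installment. Next time, we will explore the modeling process in more depth using concrete data.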
