SWA实战:使用SWA进行微调,提高模型的泛化

举报
AI浩 发表于 2022/04/25 21:40:45 2022/04/25
【摘要】 摘要论文链接:https://arxiv.org/abs/1803.05407.pdf官方代码:https://github.com/timgaripov/swa论文翻译:【第32篇】SWA:平均权重导致更广泛的最优和更好的泛化_AI浩的博客-CSDN博客SWA简单来说就是对训练过程中的多个checkpoints进行平均,以提升模型的泛化性能。记训练过程第iii个epoch的checkpo...

摘要

论文链接:https://arxiv.org/abs/1803.05407.pdf

官方代码:https://github.com/timgaripov/swa

论文翻译:【第32篇】SWA:平均权重导致更广泛的最优和更好的泛化_AI浩的博客-CSDN博客

SWA简单来说就是对训练过程中的多个checkpoints进行平均,以提升模型的泛化性能。记训练过程第 i i 个epoch的checkpoint为 w i w_{i} ,一般情况下我们会选择训练过程中最后的一个epoch的模型 w n w_{n} 或者在验证集上效果最好的一个模型 w i w^{*}_{i} 作为最终模型。但SWA一般在最后采用较高的固定学习速率或者周期式学习速率额外训练一段时间,取多个checkpoints的平均值。

pytorch使用举例:

from torch.optim.swa_utils import AveragedModel, SWALR
# 采用SGD优化器
optimizer = torch.optim.SGD(model.parameters(),lr=1e-4, weight_decay=1e-3, momentum=0.9)
# 随机权重平均SWA,实现更好的泛化
swa_model = AveragedModel(model).to(device)
# SWA调整学习率
swa_scheduler = SWALR(optimizer, swa_lr=1e-6)
for epoch in range(1, epoch + 1):
    for batch_idx, (data, target) in enumerate(train_loader):   
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        # 在反向传播前要手动将梯度清零
        optimizer.zero_grad()
        output = model(data)
        #计算losss
        loss = train_criterion(output, targets)
        # 反向传播求解梯度
        loss.backward()
        optimizer.step()
        lr = optimizer.state_dict()['param_groups'][0]['lr']   
    swa_model.update_parameters(model)
    swa_scheduler.step()
# 最后更新BN层参数
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
# 保存结果
torch.save(swa_model.state_dict(), "last.pt")

上面的代码展示了SWA的主要代码,实现的步骤:

1、定义SGD优化器。

2、定义SWA。

3、定义SWALR,调整模型的学习率。

4、开始训练,等待训练完成。

5、在每个epoch中更新模型的参数,更新学习率。

6、等待训练完成后,更新BN层的参数。

详细实现过程

环境

pyotrch:1.10

准备

在开始今天的代码前,我们要准备好训练好的模型。然后才能开始今天的代码。

实现过程

定义模型,并将训练好的模型载入,代码如下:

    model_ft = efficientnet_b1(pretrained=True)
    print(model_ft)
    num_ftrs = model_ft.classifier.in_features
    model_ft.classifier = nn.Linear(num_ftrs, classes)
    model_ft.to(DEVICE)
    model_ft = torch.load(model_path)
    print(model_ft)
    fine_epoch = 80
    fine_tune(model_ft, DEVICE, train_loader, test_loader, criterion_train, criterion_val, fine_epoch, mixup_fn,
              use_amp)

定义模型为efficientnet_b1,这里要和训练的模型保持一致。

如果保存的整个模型,则使用torch.load(model_path)载入模型,如果只保存了权重信息,则要使用model_ft=load_state_dict(torch.load(model_path)),载入模型。

然后,设置fine的epoch为80。

接下来,我们一起去看fine_tune函数中的内容。

 # 采用SGD优化器
    optimizer = torch.optim.SGD(model.parameters(),lr=1e-4, weight_decay=1e-3, momentum=0.9)
    if use_amp:
        model, optimizer = amp.initialize(model_ft, optimizer, opt_level="O1")  # 这里是“欧一”,不是“零一”

定义优化器为SGD。

如果使用混合精度,则对amp初始化。

 # 随机权重平均SWA,实现更好的泛化
 swa_model = AveragedModel(model).to(device)
 # SWA调整学习率
 swa_scheduler = SWALR(optimizer, swa_lr=1e-6)

初始化SWA。

使用SWALR调整学习率。

接下来循环epoch,这里都是比较通用的逻辑。

 for epoch in range(1, epoch + 1):
        model.train()
        train_loss = 0
        total_num = len(train_loader.dataset)
        print(total_num, len(train_loader))
        for batch_idx, (data, target) in enumerate(train_loader):
            if len(data) % 2 != 0:
                print(len(data))
                data = data[0:len(data) - 1]
                target = target[0:len(target) - 1]
                print(len(data))
            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
            samples, targets = mixup_fn(data, target)
            output = model(samples)
            loss = train_criterion(output, targets)
            optimizer.zero_grad()
            if use_amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print_loss = loss.data.item()
            train_loss += print_loss
            if (batch_idx + 1) % 10 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}'.format(
                    epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                           100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))
        swa_model.update_parameters(model)
        swa_scheduler.step()

主要步骤有:

1、计算loss。

2、是否使用amp混合精度,如果使用混合精度则使用scaled_loss反向传播求梯度,否则直接loss反向传播求梯度。

3、 swa_model.update_parameters(model)更新swa_model的参数。

4、 swa_scheduler.step()更新学习率。

等待所有的epoch执行完成后。

torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
torch.save(swa_model.state_dict(), "last.pt")

更新BN层参数。

然后保存模型的权重。注意:这里只能保存模型的权重,不能保存整个模型。

完成之后就可以测试了,执行代码:

import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import os
from torchvision.models.mobilenetv3 import mobilenet_v3_large
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel, SWALR
from timm.models.efficientnet import efficientnet_b1
import numpy as np

def show_outputs(output):

    output_sorted = sorted(output, reverse=True)
    top5_str = '-----TOP 5-----\n'
    for i in range(5):
        value = output_sorted[i]
        index = np.where(output == value)
        for j in range(len(index)):
            if (i + j) >= 5:
                break
            if value > 0:
                topi = '{}: {}\n'.format(index[j], value)
            else:
                topi = '-1: 0.0\n'
            top5_str += topi
    print(top5_str)

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = efficientnet_b1(pretrained=True)

num_ftrs = model.classifier.in_features
model.classifier = nn.Linear(num_ftrs, 8)
swa_model = AveragedModel(model)
swa_model.load_state_dict(torch.load("last.pt"))
swa_model.to(DEVICE)
swa_model.eval()

path = 'test/'
testList = os.listdir(path)
for file in testList:
    img = Image.open(path + file)
    img = transform_test(img)
    img.unsqueeze_(0)
    img = Variable(img).to(DEVICE)
    out = swa_model(img)
    out = out.data.cpu().numpy()[0]
    print(file)
    show_outputs(out)

这里测试代码和以前的写法没有啥区别,唯一不同的地方:

重新定义模型,然后载入权重。
运行结果:
image-20220425210850314

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。