【PyTorch】 99%程序员都不知道, 深度学习还能这样玩 ☢️建议手收藏☢️

举报
我是小白呀iamarookie 发表于 2021/08/30 10:49:11 2021/08/30
【摘要】 概述你还在为训练无从下手而苦恼么?你还在为模型训练时间漫长而痛苦么?你还在为模型准确率提升困难在深夜一个人啜泣么?今天教大家一个方法, 使得我们的模型起跑线上直接甩开别人几条街. 隔壁王叔叔都学会了! 迁移学习迁移学习 (Transfer Learning) 是把已学训练好的模型参数用作新训练模型的起始参数. 入住 GitHub经过几天的日夜狂肝, 本人完成了在 GitHub 上的第一个项...

概述

你还在为训练无从下手而苦恼么?
你还在为模型训练时间漫长而痛苦么?
你还在为模型准确率提升困难在深夜一个人啜泣么?

在这里插入图片描述

今天教大家一个方法, 使得我们的模型起跑线上直接甩开别人几条街. 隔壁王叔叔都学会了!

迁移学习

迁移学习 (Transfer Learning) 是把已学训练好的模型参数用作新训练模型的起始参数.

入住 GitHub

经过几天的日夜狂肝, 本人完成了在 GitHub 上的第一个项目. 把迁移学习封装成了一个有手就能用的黑盒模型.

在这里插入图片描述
大家只要替换自己的数据集就可以实现多个可选模型迁移学习并自动保存. 就是两个字简单

项目详解

GitHub 链接
在这里插入图片描述

get_data.py (获取数据)

目前支持 MNIST, Fashion MNIST, CIFAR 10 和 CIFAR 100 数据集.

可以在```get_data.py``下自行替换成自己需要的数据集:
在这里插入图片描述

传入数据的格式为:

data_loader = {"train": train_loader, "valid": test_loader}

get_model (获取模型)

目前支持:

  • resnet18
  • resnet34
  • resnet50
  • resnet101
  • resnet152
  • alexnet
  • squeezenet
  • vgg11
  • vgg13
  • vgg16
  • vgg19

替换模型的方法:

python main.py --model_name "模型名称"

例如, 使用 vgg 13:

python main.py --model_name vgg13

例如, 使用 resnet 152:

python main.py --model_name resnet152

参数详解

必填参数:

  • model_name: 模型名称, 类型为 string
  • num_classes: 输出类别数, 类型为 int (例如 MNIST 是 10 分类, CIFAR 100 是 100 分类)

重要参数:

  • data_name: 数据名称, 类型为 string, 默认为 CIFAR10
  • data_gray: 是否为灰度图, 类型为 boolean, 默认为 False
  • num_epochs: 迭代次数, 类型为 int, 默认为 20
  • batch_size: 一个批次的样本数目, 默认为 512

可选参数 (不建议修改):

  • feature_exact: 是否冻层, 类型为 boolean, 默认为 False
  • use_pretrained: 是否使用预训练权重, 类型为 boolean, 默认为 True
  • pretrained_model_path: 预训练权重, 类型为 string, 默认为 pretrained_model/
  • model_save_path: 模型保存路径, 类型为 string, 默认为 “checkpoint/”
  • visualize: 模型可视化, 类型为 boolean, 默认为 True

使用说明

首先我们需要cd到文件路径, 例如:

cd C:\Users\Windows\Desktop\Project\transfer_learning-main

训练 MNIST

使用 resnet18 训练 MNIST 数据集:

python main.py --data_name MNIST --data_gray True --model_name resnet18 --num_classes 10 --batch_size 512

训练 Fashion MNIST

使用 resnet34 训练 Fashion MNIST 数据集:

python main.py --data_name FashionMNIST --data_gray True --model_name resnet34 --num_classes 10 --batch_size 512

训练 CIFAR 10

使用 resnet50 训练 CIFAR 10 数据集:

python main.py --data_name CIFAR10 --model_name resnet50 --num_classes 10 --batch_size 512

训练 CIFAR 100

使用 resnet152 训练 CIFAR 10 数据集:

python main.py --data_name CIFAR100 --model_name resnet152 --num_classes 100 --batch_size 512

训练自己的数据

python main.py --data_name other --model_name ? --num_classes ? --batch_size ? --epochs ?
【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。