【PyTorch】 99%程序员都不知道, 深度学习还能这样玩 ☢️建议手收藏☢️
【摘要】 概述你还在为训练无从下手而苦恼么?你还在为模型训练时间漫长而痛苦么?你还在为模型准确率提升困难在深夜一个人啜泣么?今天教大家一个方法, 使得我们的模型起跑线上直接甩开别人几条街. 隔壁王叔叔都学会了! 迁移学习迁移学习 (Transfer Learning) 是把已学训练好的模型参数用作新训练模型的起始参数. 入住 GitHub经过几天的日夜狂肝, 本人完成了在 GitHub 上的第一个项...
概述
你还在为训练无从下手而苦恼么?
你还在为模型训练时间漫长而痛苦么?
你还在为模型准确率提升困难在深夜一个人啜泣么?
今天教大家一个方法, 使得我们的模型起跑线上直接甩开别人几条街. 隔壁王叔叔都学会了!
迁移学习
迁移学习 (Transfer Learning) 是把已学训练好的模型参数用作新训练模型的起始参数.
入住 GitHub
经过几天的日夜狂肝, 本人完成了在 GitHub 上的第一个项目. 把迁移学习封装成了一个有手就能用的黑盒模型.
大家只要替换自己的数据集就可以实现多个可选模型迁移学习并自动保存. 就是两个字简单
项目详解
get_data.py (获取数据)
目前支持 MNIST, Fashion MNIST, CIFAR 10 和 CIFAR 100 数据集.
可以在```get_data.py``下自行替换成自己需要的数据集:
传入数据的格式为:
data_loader = {"train": train_loader, "valid": test_loader}
get_model (获取模型)
目前支持:
- resnet18
- resnet34
- resnet50
- resnet101
- resnet152
- alexnet
- squeezenet
- vgg11
- vgg13
- vgg16
- vgg19
替换模型的方法:
python main.py --model_name "模型名称"
例如, 使用 vgg 13:
python main.py --model_name vgg13
例如, 使用 resnet 152:
python main.py --model_name resnet152
参数详解
必填参数:
- model_name: 模型名称, 类型为 string
- num_classes: 输出类别数, 类型为 int (例如 MNIST 是 10 分类, CIFAR 100 是 100 分类)
重要参数:
- data_name: 数据名称, 类型为 string, 默认为 CIFAR10
- data_gray: 是否为灰度图, 类型为 boolean, 默认为 False
- num_epochs: 迭代次数, 类型为 int, 默认为 20
- batch_size: 一个批次的样本数目, 默认为 512
可选参数 (不建议修改):
- feature_exact: 是否冻层, 类型为 boolean, 默认为 False
- use_pretrained: 是否使用预训练权重, 类型为 boolean, 默认为 True
- pretrained_model_path: 预训练权重, 类型为 string, 默认为 pretrained_model/
- model_save_path: 模型保存路径, 类型为 string, 默认为 “checkpoint/”
- visualize: 模型可视化, 类型为 boolean, 默认为 True
使用说明
首先我们需要cd
到文件路径, 例如:
cd C:\Users\Windows\Desktop\Project\transfer_learning-main
训练 MNIST
使用 resnet18 训练 MNIST 数据集:
python main.py --data_name MNIST --data_gray True --model_name resnet18 --num_classes 10 --batch_size 512
训练 Fashion MNIST
使用 resnet34 训练 Fashion MNIST 数据集:
python main.py --data_name FashionMNIST --data_gray True --model_name resnet34 --num_classes 10 --batch_size 512
训练 CIFAR 10
使用 resnet50 训练 CIFAR 10 数据集:
python main.py --data_name CIFAR10 --model_name resnet50 --num_classes 10 --batch_size 512
训练 CIFAR 100
使用 resnet152 训练 CIFAR 10 数据集:
python main.py --data_name CIFAR100 --model_name resnet152 --num_classes 100 --batch_size 512
训练自己的数据
python main.py --data_name other --model_name ? --num_classes ? --batch_size ? --epochs ?
【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)