《深度学习:主流框架和编程实战》——2.3.3 ResNet程序实现

举报
华章计算机 发表于 2019/06/04 19:52:10 2019/06/04
【摘要】 本书摘自《深度学习:主流框架和编程实战》——书中第2章,第2.3.3节,作者是赵涓涓、强彦。

2.3.3 ResNet程序实现

接下来就正式进入ResNet的实例指导,本实例通过ResNet来解决对Cifar-10数据集的分类。

(1)预要求

在运行前需要预安装Pandas库、numpy库、OpenCV、TensorFlow(1.0.0+)。

(2)文件组织结构

该实例文件夹中的文件组织结构如图2-9所示。其中cifar10_input.py包括下载、提取和预处理Cifar-10图像的函数。resnet.py定义了ResNet结构。cifar10_train.py负责训练和验证。cifar10_test.py负责测试图像。hyper_parameters.py定义了关于训练ResNet网络结构、数据变量的超参数。cifar10_main.py为程序执行的起始文件,包括执行训练和测试两个部分,可以通过执行此文件开始程序。cifar10_data文件夹中存放运行程序所需的数据集,logs_test文件夹中存放验证日志程序,包括程序运行后产生的训练误差、训练损失、验证误差、验证损失生成的test.csv统计表。testdata文件夹中存放测试图像。

image.png

图2-9 文件组织结构

(3)数据集

对于数据集来说,本章采用的是Cifar-10数据集,其中Cifar-10数据集包含有6万张32×32彩色——图像,由Alex Krizhevsky、Vinod Nair和Geoffrey Hinton收集而来,包含50 000张训练图片、10 000张测试图片,用户也可以在其官方网站(http://www.cs.toronto.edu/~kriz/cifar.html)进行下载。其中,训练批次中包含来自每个类的5000张图像。用户可以通过本节的程序来对Cifar-10数据集进行分类,最后查看其训练误差与验证损失。

(4)超参数

hyper_parameters.py文件定义了所有的超参数,用户可以自己定制训练参数。用户可以使用python cifar10_train.py --hyper_parameter1=value1 --hyper_parameter2=value2来设置所有的超参数,也可以改变Python脚本中的默认值。共有以下五类超参数。

1)关于保存训练日志、tensorboard和屏幕输出的超参数:

version(str):检查点和输出时间会被存放在logs_version中。

report_freq(int):在训练过程中,每隔report-freq次进行一次验证集验证,并输出验证结果。

train_ema_decay(float):tensorboard将记录训练误差的移动平均值。这个衰减因子在TensorFlow中用tf.train.ExponentialMovingAverage(FLAGS.train_ema_decay,global_step)来定义一个ExponentialMovingAverage。

2)关于训练过程的超参数:

train_steps(int):训练总步骤。

is_full_validation(boolean):如果用户想使用所有的10 000张验证图像进行验证,则参数设为True,或者想随机使用一批验证数据,则参数设为False。

train_batch_size(int):训练批次大小。

validation_batch_size(int):验证批次大小(is_full_validation=False才会有效)。

init_lr(float):初始化的学习率。根据下列设置,学习率可能会衰减。

lr_decay_factor(float):学习率衰减因子。学习率在每次衰减时会变成lr_decay_factor*current_learning_rate。

decay_step0(int):在decay_step0中,学习率会第一次衰减。

decay_step1(int):学习率的第二次衰减。

3)关于控制网络的超参数:

num_residual_blocks(int):ResNet总层数= 6×num_residual_blocks+2。

weight_decay(float):权重衰减用来正则化网络,total_loss = train_loss + weight_

decay×?权重的平方和。

4)关于数据变量的超参数:

padding_size(int):padding_size是在图像的每一侧加上填充的行(列)数。填充和随机裁剪可以防止过拟合问题。

5)关于加载检查点的超参数:

ckpt_path(str):用户想载入的检查点路径。

is_use_chpt(boolean):如果is_use_chpt=True,可以使用检查点,继续从检查点执行训练。

(5)训练

Train()定义了所有关于训练阶段的类,主要观点是运行train_op FLAGS.train_steps次。如果步数%FLAGS.report_freq == 0,则会立即验证、训练并在tensorboard上写下所有的总结。

(6)测试

Train()类中的test()函数会帮助用户预测,它会返回一个模型[num_test_images, num_labels]的softmax概率。用户需要准备和预处理测试数据,并将它传到函数中。用户既可以使用自己的检查点,也可以使用预训练的ResNet-110检查点。


【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。