《深度学习:主流框架和编程实战》——2.3.3 ResNet程序实现
2.3.3 ResNet程序实现
接下来就正式进入ResNet的实例指导,本实例通过ResNet来解决对Cifar-10数据集的分类。
(1)预要求
在运行前需要预安装Pandas库、numpy库、OpenCV、TensorFlow(1.0.0+)。
(2)文件组织结构
该实例文件夹中的文件组织结构如图2-9所示。其中cifar10_input.py包括下载、提取和预处理Cifar-10图像的函数。resnet.py定义了ResNet结构。cifar10_train.py负责训练和验证。cifar10_test.py负责测试图像。hyper_parameters.py定义了关于训练ResNet网络结构、数据变量的超参数。cifar10_main.py为程序执行的起始文件,包括执行训练和测试两个部分,可以通过执行此文件开始程序。cifar10_data文件夹中存放运行程序所需的数据集,logs_test文件夹中存放验证日志程序,包括程序运行后产生的训练误差、训练损失、验证误差、验证损失生成的test.csv统计表。testdata文件夹中存放测试图像。
图2-9 文件组织结构
(3)数据集
对于数据集来说,本章采用的是Cifar-10数据集,其中Cifar-10数据集包含有6万张32×32彩色——图像,由Alex Krizhevsky、Vinod Nair和Geoffrey Hinton收集而来,包含50 000张训练图片、10 000张测试图片,用户也可以在其官方网站(http://www.cs.toronto.edu/~kriz/cifar.html)进行下载。其中,训练批次中包含来自每个类的5000张图像。用户可以通过本节的程序来对Cifar-10数据集进行分类,最后查看其训练误差与验证损失。
(4)超参数
hyper_parameters.py文件定义了所有的超参数,用户可以自己定制训练参数。用户可以使用python cifar10_train.py --hyper_parameter1=value1 --hyper_parameter2=value2来设置所有的超参数,也可以改变Python脚本中的默认值。共有以下五类超参数。
1)关于保存训练日志、tensorboard和屏幕输出的超参数:
version(str):检查点和输出时间会被存放在logs_version中。
report_freq(int):在训练过程中,每隔report-freq次进行一次验证集验证,并输出验证结果。
train_ema_decay(float):tensorboard将记录训练误差的移动平均值。这个衰减因子在TensorFlow中用tf.train.ExponentialMovingAverage(FLAGS.train_ema_decay,global_step)来定义一个ExponentialMovingAverage。
2)关于训练过程的超参数:
train_steps(int):训练总步骤。
is_full_validation(boolean):如果用户想使用所有的10 000张验证图像进行验证,则参数设为True,或者想随机使用一批验证数据,则参数设为False。
train_batch_size(int):训练批次大小。
validation_batch_size(int):验证批次大小(is_full_validation=False才会有效)。
init_lr(float):初始化的学习率。根据下列设置,学习率可能会衰减。
lr_decay_factor(float):学习率衰减因子。学习率在每次衰减时会变成lr_decay_factor*current_learning_rate。
decay_step0(int):在decay_step0中,学习率会第一次衰减。
decay_step1(int):学习率的第二次衰减。
3)关于控制网络的超参数:
num_residual_blocks(int):ResNet总层数= 6×num_residual_blocks+2。
weight_decay(float):权重衰减用来正则化网络,total_loss = train_loss + weight_
decay×?权重的平方和。
4)关于数据变量的超参数:
padding_size(int):padding_size是在图像的每一侧加上填充的行(列)数。填充和随机裁剪可以防止过拟合问题。
5)关于加载检查点的超参数:
ckpt_path(str):用户想载入的检查点路径。
is_use_chpt(boolean):如果is_use_chpt=True,可以使用检查点,继续从检查点执行训练。
(5)训练
Train()定义了所有关于训练阶段的类,主要观点是运行train_op FLAGS.train_steps次。如果步数%FLAGS.report_freq == 0,则会立即验证、训练并在tensorboard上写下所有的总结。
(6)测试
Train()类中的test()函数会帮助用户预测,它会返回一个模型[num_test_images, num_labels]的softmax概率。用户需要准备和预处理测试数据,并将它传到函数中。用户既可以使用自己的检查点,也可以使用预训练的ResNet-110检查点。
- 点赞
- 收藏
- 关注作者
评论(0)