《深度学习:主流框架和编程实战》——2.3.4 详细代码解析(1)

举报
华章计算机 发表于 2019/06/05 16:44:26 2019/06/05
【摘要】 本书摘自《深度学习:主流框架和编程实战》——书中第2章,第2.3.4节,作者是赵涓涓、强彦。

2.3.4 详细代码解析(1)

1)cifar10_main.py文件是该图像分类程序的入口,通过调用自定义的训练函数和测试函数开始训练网络,并在训练完毕后对网络进行测试。

cifar10_main.py

# 导入cifar10_train中所有函数

from cifar10_train import *

# 导入cifar10_test中所有函数

from cifar10_test import *

# 训练网络

# maybe_download_and_extract()

# train_object=Train()

# train_object.train()

# 测试网络进行单张图像分类

test_object=Test(test_dir)

test_object.test()

test_object.disp_k_result(10)

2)cifar10_train.py文件定义了Train类,负责对所有训练和验证的处理,包括对占位符的定义、构建训练验证图函数的定义、训练函数的定义、测试函数的定义、损失函数的定义、误差函数的定义、生成验证集的定义、生成变量训练批次的定义、训练操作的定义、验证操作的定义、全验证集的定义,程序会跨过类执行。

cifar10_train.py

#导入resnet中所有的包

#主要用于统计激活函数,创建变量,构建输出层,构建批次正则化层,定义残差构件,测试图的功能

from resnet import *

#导入datetime包,主要用于记录运行时间

from datetime import datetime

#导入time包,主要用于记录当前时间

import time

# 导入cifar10_input中所有的包

# 主要用于下载提取训练数据,读取单次训练批次,读取所有图像,垂直翻转,白化图像

# 随机裁剪和翻转,准备训练数据,读取验证数据

from cifar10_input import *

# 以pd的形式导入pandas包,主要围绕series和DataFrame对数据结构进行处理

import pandas as pd

# 定义Train类

class Train(object):

初始化函数:总共有五类占位符,分别为训练图像占位符和训练标签占位符、验证图像占位符和验证标签占位符、学习率占位符。

def __init__(self):

    #建立占位符placeholders

    self.placeholders()

    #占位符函数

    def placeholders(self):

        # 定义图像占位符image_placeholder 类型为float32,随机训练

        #训练批次大小,图像高度,图像宽度,图像深度

        self.image_placeholder = tf.placeholder(dtype=tf.float32,

                                        shape=[FLAGS.train_batch_size,

                                                IMG_HEIGHT,

                                                IMG_WIDTH,

                                                IMG_DEPTH])

        #定义标签占位符label_placeholder 类型为int32,随机训练,训练批次大小

        self.label_placeholder = tf.placeholder(dtype=tf.int32,

                                        shape=[FLAGS.train_batch_size])

        #定义验证图像占位符vali_image_placeholder 类型为float32

        #随机验证,验证批次大小,图像高度,图像宽度,图像深度

        self.vali_image_placeholder = tf.placeholder(dtype=tf.float32,

                                        shape=[FLAGS.validation_batch_size,

                                                IMG_HEIGHT,

                                                IMG_WIDTH,

                                                IMG_DEPTH])

        #定义验证标签占位符vali_label_placeholder 类型为int32

        #随机验证,验证批次大小

        self.vali_label_placeholder = tf.placeholder(dtype=tf.int32,

                                        shape=[FLAGS.validation_batch_size])

        #定义学习率占位符lr_placeholder 类型为float32

        self.lr_placeholder = tf.placeholder(dtype=tf.float32, shape=[])

    #建立训练验证图函数,函数会同时建立训练图和验证图

    def build_train_validation_graph(self):

        #定义全局步数 global_step变量值为0,不可训练

        global_step = tf.Variable(0, trainable=False)

        #定义验证步数validation_step变量值为0,不可训练

        validation_step = tf.Variable(0, trainable=False)

        #定义比数logits占位符为图像占位符,残差构件的数目,不可重新使用

        logits = inference(self.image_placeholder,

                                FLAGS.num_residual_blocks,

                                reuse=False)

        #定义验证比数vali_logits占位符为验证图像占位符

        #残差构件的数目,重新使用

        vali_logits = inference(self.vali_image_placeholder,

                                FLAGS.num_residual_blocks, 

                                reuse=True)

        #计算训练损失,由softmax交叉熵和正则化损失组成

        #定义正则化损失regu_losses为tf.GraphKeys.REGULARIZATION_LOSSES

        regu_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

        #定义损失 loss=logits,占位符为标签占位符

        loss = self.loss(logits, self.label_placeholder)

        #定义全局损失full_loss等于损失加正则化损失

        self.full_loss = tf.add_n([loss] + regu_losses)

        #定义预测predictions等于softmax(logits)

        predictions = tf.nn.softmax(logits)

        #定义训练顶层误差train_top1_error等于预测,占位符为标签占位符,有效

        self.train_top1_error = self.top_k_error(predictions, 

                                                self.label_placeholder,

                                                1)

        #定义验证损失vali_loss为验证比数,占位符为验证标签占位符

        self.vali_loss = self.loss(vali_logits, self.vali_label_placeholder)

        #定义验证预测vali_predictions 为softmax(vali_logits)

        vali_predictions = tf.nn.softmax(vali_logits)

        #定义验证顶层误差vali_top1_error为验证预测,占位符为验证占位符,有效

        self.vali_top1_error = self.top_k_error(vali_predictions, 

                                                self.vali_label_placeholder,

                                                1)

        #定义训练操作,指数平均移动数为全局步骤,全局损失,训练顶层误差

        self.train_op, self.train_ema_op = self.train_operation(global_step,

                                                        self.full_loss,

                                                        self.train_top1_error)

        #定义验证操作val_op为验证步数,验证顶层误差,验证误差

        self.val_op = self.validation_op(validation_step, 

                                            self.vali_top1_error,

                                            self.vali_loss)

训练函数(训练的主函数):第一步,将所有训练图像和验证图像存入内存中;第二步,建立训练图和验证图;第三步,初始化一个存储器来保存检查点,合并所有的结果,以便能通过运行summary_op展示操作。

def train(self):

    # 读取训练数据以及验证数据

    # 调用cifar10_input.py 包中的prepare_train_data函数和read_validation_data函数

    all_data, all_labels = prepare_train_data(padding_size=FLAGS.padding_size)

    vali_data, vali_labels = read_validation_data()

    #建立训练图和验证图

    self.build_train_validation_graph()

    #定义存储器saver为全局变量

    saver = tf.train.Saver(tf.global_variables())

    #定义总结操作summary_op为合并所有操作

    summary_op = tf.summary.merge_all()

    #定义初始化变量init为初始化所有变量

    init = tf.initialize_all_variables()

    #定义一个新的会话sess

    sess = tf.Session()

    #如果is_use_ckpt为真,则从检查点ckpt_path载入网络文件

    if FLAGS.is_use_ckpt is True:

            saver.restore(sess, FLAGS.ckpt_path)

            print 'Restored from checkpoint...'

    #否则重新初始化

    else:

            sess.run(init)

    #定义总结summary_writer 为train_dir, sess.graph

    summary_writer = tf.summary.FileWriter(train_dir, sess.graph)

    #定义列表step_list

    step_list = []

    #定义训练误差列表train_error_list 

    train_error_list = []

    #定义验证误差列表val_error_list

    val_error_list = []

    print 'Start training...'

    print '----------------------------'

    #对xrange进行step次遍历FLAGS.train_steps

    for step in xrange(FLAGS.train_steps):

        #生成训练批次数据、训练批次标签,并定义训练批次大小

        train_batch_data, train_batch_labels =

                        self.generate_augment_train_batch(all_data, 

                                                    all_labels,

                                                    FLAGS.train_batch_size)

        #生成验证批次数据、验证批次标签,并定义验证批次大小

        validation_batch_data, validation_batch_labels = 

                        self.generate_vali_batch(vali_data,

                                                vali_labels,

                                                FLAGS.validation_batch_size)

        #每隔 FLAGS.report_freq次训练将训练结果进行输出

        if step % FLAGS.report_freq == 0:

            #是否进行全局验证

            if FLAGS.is_full_validation is True:

                validation_loss_value, validation_error_value = 

                            self.full_validation(

                                loss=self.vali_loss,

                                top1_error=self.vali_top1_error, 

                                vali_data=vali_data,vali_labels=vali_labels,

                                session=sess,batch_data=train_batch_data, 

                                batch_label=train_batch_labels)

                #定义验证和函数vali_summ 

                vali_summ = tf.Summary()

                #将此次验证集错误率存储在full_validation_error变量中,类型为float

                vali_summ.value.add(tag='full_validation_error',

                        simple_value=validation_error_value.astype(np.float))

                #写入验证和vali_summ,步数为step

                summary_writer.add_summary(vali_summ, step)

                #刷新 summary_writer

                summary_writer.flush()

            else:

                _, validation_error_value, validation_loss_value = 

                    sess.run([self.val_op,self.vali_top1_error,self.vali_loss],

                        {self.image_placeholder: train_batch_data,

                        self.label_placeholder: train_batch_labels,

                        self.vali_image_placeholder:validation_batch_data,

                        self.vali_label_placeholder:validation_batch_labels,

                        self.lr_placeholder:FLAGS.init_lr})

                #更新验证误差列表

                val_error_list.append(validation_error_value)

        #计时开始

        start_time = time.time()

        #网络开始训练 

        _, _, train_loss_value, train_error_value = sess.run([self.train_op,

                                                    self.train_ema_op,

                                                self.full_loss,

                                                self.train_top1_error],

                                                {self.image_placeholder: 

                                                     train_batch_data,

                                                     self.label_placeholder: 

                                                     train_batch_labels,

                                                     self.vali_image_placeholder: 

                                                     validation_batch_data,

                                                     self.vali_label_placeholder:

                                                     validation_batch_labels,

                                                     self.lr_placeholder:

                                                     FLAGS.init_lr})

        #计算训练时间

        duration = time.time() - start_time

        #在指定间隔内输出网络性能检测

        if step % FLAGS.report_freq == 0:

            summary_str = sess.run(summary_op, 

                                    {self.image_placeholder:train_batch_data,

                                    self.label_placeholder: train_batch_labels,

                                    self.vali_image_placeholder: validation_batch_data,

                                    self.vali_label_placeholder: validation_batch_labels,

                                    self.lr_placeholder: FLAGS.init_lr})

            #写入总结字符串summary_str,步数为step

            summary_writer.add_summary(summary_str, step)

            #每步样例数num_examples_per_step为训练批次大小

            num_examples_per_step = FLAGS.train_batch_size

            #单个样例训练时间

            examples_per_sec = num_examples_per_step / duration

            #每批次运行时间

            sec_per_batch = float(duration)

            #定义字符串format_str

            format_str = ('%s: step %d, loss = %.4f (%.1f examples/sec; %.3f ' 'sec/batch)')

            #终端输出format_str

            print format_str % (datetime.now(), 

                                step,

                                train_loss_value,

                                examples_per_sec,

                                sec_per_batch)

            #输出top1训练误差值

            print 'Train top1 error = ', train_error_value

            #输出验证层误差

            print 'Validation top1 error = %.4f' % validation_error_value

            #输出验证损失值

            print 'Validation loss = ', validation_loss_value

            print '----------------------------'

            #更新步数列表

            step_list.append(step)

            #更新训练误差列表

            train_error_list.append(train_error_value)

            #根据FLAGS.decay_step0和FLAGS.decay_step1调整学习率

            if step == FLAGS.decay_step0 or step == FLAGS.decay_step1:

                        #初始化学习率FLAGS.init_lr 等于0.1*FLAGS.init_lr

                        FLAGS.init_lr = 0.1 * FLAGS.init_lr

                        #输出学习率衰减到FLAGS.init_lr

                        print 'Learning rate decayed to ', FLAGS.init_lr

                        #如果步数step%10000等于0或是步数step加1等于训练步数

            if step % 10000 == 0 or (step + 1) == FLAGS.train_steps:

                        #更新检查点路径checkpoint_path

                            checkpoint_path = os.path.join(train_dir, 'model.ckpt')

                        #存储会话

                        saver.save(sess, checkpoint_path, global_step=step)

                        df = pd.DataFrame(data={'step':step_list, 

                                             'train_error':train_error_list,

                                             'validation_error': val_error_list})

                        #将文件以csv格式存储在训练目录中

                        df.to_csv(train_dir + FLAGS.version + '_error.csv')

损失值计算函数:计算给定比数和真实标签的交叉熵,包含参数logits,格式为2D张量;labels,格式为1D张量;返回损失张量。

def loss(self, logits, labels):

    #定义标签labels的类型为int64

    labels = tf.cast(labels, tf.int64)

    #定义交叉熵cross_entropy为具有比数的稀疏softmax交叉熵

    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels, name='cross_entropy_per_example')

    #定义平均交叉熵cross_entropy_mean为降低平均

    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')

    #返回平均交叉熵

    return cross_entropy_mean

分类top_k误差计算函数:包含参数predictions,格式为2D张量;labels,格式为1D张量;返回误差。

def top_k_error(self, predictions, labels, k):

    #定义批次大小batch_size预测

    batch_size = predictions.get_shape().as_list()[0]

    #定义顶层in_top1 为预测,标签,k=1

    in_top1 = tf.to_float(tf.nn.in_top_k(predictions, labels, k=1))

    #定义正确数num_correct与顶层有关

    num_correct = tf.reduce_sum(in_top1)

    return (batch_size - num_correct) / float(batch_size)

生成验证批次函数:该函数能够随机生成数据批次,而不是整个验证数据,包含参数vali_data,格式为4D张量;vali_label,1D numpy数组,vali_batch_size,整型;返回验证图像及标签。

    def generate_vali_batch(self, vali_data, vali_label, vali_batch_size):

    #偏移为随机从(10000-验证批次大小)中选取

    offset = np.random.choice(10000 - vali_batch_size, 1)[0]

    #定义验证数据批次为偏移加验证批次大小

    vali_data_batch = vali_data[offset:offset+vali_batch_size, ...]

    #定义验证标签批次为偏移加验证批次大小

    vali_label_batch = vali_label[offset:offset+vali_batch_size]

    #返回验证数据批次、验证标签批次

    return vali_data_batch, vali_label_batch

生成训练批次函数:这个函数能够帮助生成训练数据批次,随机裁剪,同时水平翻转并白化它们。包含参数train_data,格式为4D numpy 数组;train_labels,格式为1D numpy数组;train_batch_size,格式为整型;返回批次数据、批次标签。

def generate_augment_train_batch(self, train_data, train_labels, train_batch_size):

    #偏移为随机从(迭代大小–训练批次大小)中选取

    offset = np.random.choice(EPOCH_SIZE - train_batch_size, 1)[0]

    #定义批次数据为偏移加训练批次大小

    batch_data = train_data[offset:offset+train_batch_size, ...]

    #定义批次数据为随机裁剪和翻转

    batch_data = random_crop_and_flip(batch_data, padding_size=FLAGS.padding_size)

    #定义批次数据为白化图像

    batch_data = whitening_image(batch_data)

    #定义批次标签为偏移加训练批次大小

    batch_label = train_labels[offset:offset+FLAGS.train_batch_size]

    #返回批次数据、批次标签

    return batch_data, batch_label

训练操作函数:定义训练操作。包含参数global_step,1D张量;total_loss,1D张量;top1_error,1D张量;返回训练操作、训练ema操作,运行train_ema_op会为tensorboard生成训练误差和训练损失的移动平均数。

def train_operation(self, global_step, total_loss, top1_error):

    #定义标量学习率learning_rate,为学习率占位符lr_placeholder

    tf.summary.scalar('learning_rate', self.lr_placeholder)

    #定义标量训练损失train_loss,为总损失total_loss

    tf.summary.scalar('train_loss', total_loss)

    #定义标量训练顶层误差train_top1_error,为顶层误差top1_error

    tf.summary.scalar('train_top1_error', top1_error)

    #定义ema指数平均移动数为衰减和全局步数的移动平均数

    ema = tf.train.ExponentialMovingAverage(FLAGS.train_ema_decay, global_step)

    #定义训练指数平均移动数与全局损失、顶层误差有关

    train_ema_op = ema.apply([total_loss, top1_error])

    #定义标量平均训练顶层误差train_top1_error_avg,为ema.average(top1_error)

    tf.summary.scalar('train_top1_error_avg', ema.average(top1_error))

    #定义标量平均训练损失train_loss_avg,为ema.average(total_loss)

    tf.summary.scalar('train_loss_avg', ema.average(total_loss))

    #定义opt训练动量优化器,学习率为学习率占位符,动量为0.9

    opt = tf.train.MomentumOptimizer(learning_rate=self.lr_placeholder, momentum=0.9)

    #定义训练操作train_op为最小总损失,全局步数

    train_op = opt.minimize(total_loss, global_step=global_step)

    #返回训练操作、训练指数平均移动数操作

    return train_op, train_ema_op

验证操作函数:定义验证操作。包含参数validation_step,1D张量;top1_error,1D张量;loss,1D张量;返回验证操作。

def validation_op(self, validation_step, top1_error, loss):

    #定义ema指数平均移动数为0.0和验证步数

    ema = tf.train.ExponentialMovingAverage(0.0, validation_step)

    #定义ema2指数平均移动数为0.95和验证步数

    ema2 = tf.train.ExponentialMovingAverage(0.95, validation_step)

    #定义验证操作val_op与验证步骤分配

    val_op = tf.group(validation_step.assign_add(1),

                            ema.apply([top1_error,loss]),

                            ema2.apply([top1_error, loss]))

    #误差验证top1_error_val为误差的平均值

    top1_error_val = ema.average(top1_error)

    #平均误差top1_error_avg为误差的平均值

    top1_error_avg = ema2.average(top1_error)

    #损失值loss_val为损失的平均值

    loss_val = ema.average(loss)

    #平均损失值为损失的平均值

    loss_val_avg = ema2.average(loss)

    #定义标量验证顶层误差val_top1_error,为top1_error_val

    tf.summary.scalar('val_top1_error', top1_error_val)

    #定义标量验证顶层平均误差val_top1_error_avg,为val_top1_error_avg

    tf.summary.scalar('val_top1_error_avg', top1_error_avg)

    #定义标量验证损失val_loss,为val_loss

    tf.summary.scalar('val_loss', loss_val)

    #定义标量验证平均损失val_loss_avg,为val_loss_avg

    tf.summary.scalar('val_loss_avg', loss_val_avg)

    #返回验证操作

    return val_op

全局验证函数:运行在10000张验证图像的验证函数,包含参数loss,1D张量;top1_error,1D张量;vali_data,4D张量;vali_labels,1D张量;batch_data,4D张量;batch_label,1D张量。返回平均损失和误差损失。

def full_validation(self, loss, top1_error, session, vali_data, vali_labels, 

                    batch_data,batch_label):

            #批次大小为10000

            num_batches = 10000 // FLAGS.validation_batch_size

            #随机排序

            order = np.random.choice(10000, num_batches * FLAGS.validation_batch_size)

            #验证数据子集vali_data_subset从order开始

            vali_data_subset = vali_data[order, ...]

            #验证标签子集vali_labels_subset为order的标签

            vali_labels_subset = vali_labels[order]

            #定义损失列表loss_list 

            loss_list = []

            #定义误差列表error_list 

            error_list = []

            #对range进行step次遍历num_batches

            for step in range(num_batches):

                #定义偏移offset为步数step*验证批次大小validation_batch_size

                offset = step * FLAGS.validation_batch_size

                #定义填充字典feed_dict 

                feed_dict = {self.image_placeholder: batch_data,

                            self.label_placeholder:batch_label,

                            self.vali_image_placeholder:

                                      vali_data_subset[offset:offset+FLAGS.validation_batch_size, ...],

                            self.vali_label_placeholder:

                                      vali_labels_subset[offset:offset+FLAGS.validation_batch_size],

                            self.lr_placeholder:FLAGS.init_lr}

                #定义损失值loss_value,顶层误差值为top1_error_value

                loss_value, top1_error_value = session.run([loss, top1_error], 

                                                            feed_dict=feed_dict)

                #更新损失列表loss_list

                loss_list.append(loss_value)

                #更新误差列表error_list

                error_list.append(top1_error_value)

            #返回损失列表的平均值,误差列表的平均值

            return np.mean(loss_list), np.mean(error_list)



【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。