《深度学习:主流框架和编程实战》——2.3.4 详细代码解析(2)

举报
华章计算机 发表于 2019/06/05 16:52:23 2019/06/05
【摘要】 本书摘自《深度学习:主流框架和编程实战》——书中第2章,第2.3.4节,作者是赵涓涓、强彦。

2.3.4 详细代码解析(2)

3)cifar10_test.py文件定义了Test类,用来预测图像的类别。其中分别定义了测试函数、获取顶层标签函数、显示结果函数。

cifar10_test.py

#从resnet中导入所有函数

from resnet import *

#从datetime中导入datetime

from datetime import datetime

#导入time包

import time

#从cifar10_input中导入所有函数

from cifar10_input import *

#以pd的形式导入pandas

import pandas as pd

#从PIL中导入Image

from PIL import Image

#导入numpy

import numpy

#导入os

import os

#Test类

class Test(object):

    #定义初始化函数

    def __init__(self,pathname):

        #测试图像路径为路径名称

        self.test_image_path=pathname

        #测试图像占位符,参数包括占位符类型、测试批次大小、图像信息

        self.test_image_placeholder = tf.placeholder(dtype=tf.float32, 

                                                     shape=[FLAGS.test_batch_size,

        IMG_HEIGHT,IMG_WIDTH,IMG_DEPTH])

        #测试图像数组

        self.test_image_array=[]

    #定义测试函数

    def test(self):

        #测试图像数组

        def __init__(self,pathname):

            #路径地址为测试路径地址

            pathDir = os.listdir(self.test_image_path)

            #遍历路径地址

            for allDir in pathDir:

                #将测试图像数据的地址加入child中

                child = os.path.join('%s%s' % (self.test_image_path, allDir))

                #测试图像为32×32大小的数组

                testimage = numpy.asarray(Image.open(child).resize((32, 32), 

                                         Image.ANTIALIAS))

                #将测试图像加入测试图像数组中

                test_image_array.append(testimage)

            #将测试图像转化为数组类型

            self.test_image_array = numpy.array(test_image_array)

            #统计测试图像数量

            num_test_images = len(self.test_image_array)

            #统计批次数量

            num_batches = num_test_images // FLAGS.test_batch_size

            #统计剩余图像

            remain_images = num_test_images % FLAGS.test_batch_size

            print '%i test batches in total...' %num_batches

            #定义比数logits占位符为图像占位符、残差构件的数目、不可重新使用

            logits = inference(self.test_image_placeholder, 

                                FLAGS.num_residual_blocks,

                                reuse=False)

            #定义预测predictions等于softmax(logits)

            predictions = tf.nn.softmax(logits)

            #将训练中所有变量保存到saver中

            saver = tf.train.Saver(tf.all_variables())

            #定义一个新的会话sess

            sess = tf.Session()

            #从FLAGS.test_ck pt_path恢复会话

            saver.restore(sess, FLAGS.test_ckpt_path)

            #输出从FLAGS.test_ckpt_path恢复模型

            print 'Model restored from ', FLAGS.test_ckpt_path

            #定义预测数组

            prediction_array = np.array([]).reshape(-1, NUM_CLASS)

            #对range进行step次遍历num_batches

            for step in range(num_batches):

            #如果步数step能被10整除

            if step % 10 == 0:

                #输出完成步数step批次

                print '%i batches finished!' %step

                #偏移为(步数*测试批次大小)

                offset = step * FLAGS.test_batch_size

                #测试图像更新为[offset:offset+FLAGS.test_batch_size, ...]

                test_image_batch = 

                         self.test_image_array[offset:offset+FLAGS.test_batch_size, ...]

                #批次预测数组与预测、测试图像占位符、测试图像批次有关

                batch_prediction_array = sess.run(predictions,

                                    feed_dict={self.test_image_placeholder:

                                                        test_image_batch})

                #预测数组连接

                ? prediction_array = np.concatenate((prediction_array, batch_prediction_array))

            #如果残余图像不为0

            if remain_images != 0:

                self.test_image_placeholder = tf.placeholder(dtype=tf.float32,

                                                        shape=[remain_images,

                                                                IMG_HEIGHT,

                                                                IMG_WIDTH,

                                                                IMG_DEPTH])

                logits = inference(self.test_image_placeholder,

                                    FLAGS.num_residual_blocks,

                                    reuse=True)

                #定义预测predictions等于softmax(logits)

                predictions = tf.nn.softmax(logits)

                #定义测试图像批次为具有残余图像参数的测试图像数组

                test_image_batch = self.test_image_array[-remain_images:, ...]

                #批次预测数组与预测、测试图像占位符、测试图像批次有关

                batch_prediction_array = sess.run(predictions,

                                    feed_dict={self.test_image_placeholder: 

                                                            test_image_batch})

                #预测数组连接

             ?    prediction_array = np.concatenate((prediction_array, batch_prediction_array))

            self.prediction_array=prediction_array

            #返回预测数组

            return prediction_array

#定义获取顶层标签函数,参数为predict_array

def get_top_1_label(self,predict_array):

    #定义数组维度为预测数组长度

    array_dim = len(predict_array)

    #预测标签初始化,为零

    predict_label = 0

    #最大预测值为预测数组第一个值

    maxpredict = predict_array[0]

    #遍历数组维度

    for i in range(array_dim):

        #如果最大预测值小于当前值

        if maxpredict < predict_array[i]:

                #更新最大预测值为当前值

                maxpredict = predict_array[i]

                #预测标签为i

                predict_label = i

        #返回预测标签

        return predict_label

    #定义显示top_k结果函数,参数为k

    def disp_k_result(self,k):

        #类别标签

        label_name = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

        #遍历k

        for i in range(k):

            #定义临时图片数组,存放测试图像

            tempimage = Image.fromarray(self.test_image_array[i, ...])

            #图像名为测试图像+i

            imagename='test%i.jpg'%i

            #存储图像名

            tempimage.save(imagename)

            #定义临时预测标签

            temp_predict_label = self.get_top_1_label(self.prediction_array[i, ...])

            #输出图像名加预测值加标签名

            print imagename+'  predict:  '+label_name[temp_predict_label]

4)cifar10_input.py文件自动地下载并提取cifar10数据,读入训练数据、验证数据,并对数据做水平翻转、裁剪、白化等操作。

cifar10_input.py

#导入tarfile包,主要用于压缩、解压tar文件

import tarfile

#从six.moves中导入urllib包,主要用于操作url

from six.moves import urllib

#导入sys包,主要包含Python解释器和与它的环境有关的函数

import sys

#以np的形式导入numpy包,主要用于利用数组表示向量、矩阵数据结构

import numpy as np

#导入cPickle包,主要用于将内存中的对象转换成为文本流

import cPickle

#导入os包,包括各种各样的函数,以实现操作系统的许多功能

import os

#导入cv2包,包含OpenCV主要函数

import cv2

#数据路径

data_dir = 'cifar10_data'

#数据完全路径

full_data_dir = 'cifar10_data/cifar-10-batches-py/data_batch_'

#验证数据路径

vali_dir = 'cifar10_data/cifar-10-batches-py/test_batch'

#数据的统一资源定位符

DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'

#图像宽度

IMG_WIDTH = 32

#图像高度

IMG_HEIGHT = 32

#图像深度

IMG_DEPTH = 3

#类的数目

NUM_CLASS = 10

#训练随机标签TRAIN_RANDOM_LABEL为假        #想要从训练数据中使用随机标签

TRAIN_RANDOM_LABEL = False # Want to use random label for train data?

#验证随机标签VALI_RANDOM_LABEL为假        #想要从验证集使用随机标签

VALI_RANDOM_LABEL = False # Want to use random label for validation?

#训练批次数NUM_TRAIN_BATCH为5       #想要读入多少文件批次,从0到5

NUM_TRAIN_BATCH = 5# How many batches of files you want to read in, from 0 to 5)

#全批次大小

EPOCH_SIZE = 10000 * NUM_TRAIN_BATCH

#下载提取函数,会自动地下载并提取cifar10数据

def maybe_download_and_extract():

        #目标目录dest_directory为data_dir

        dest_directory = data_dir

        #如果指定目录不存在,则创建目标目录

        if not os.path.exists(dest_directory):

            #生成目录(dest_directory)

            os.makedirs(dest_directory)

            #文件名为数据的统一资源定位符以'/'分割

            filename = DATA_URL.split('/')[-1]

            #文件路径filepath将文件名filename加入目标目录

            filepath = os.path.join(dest_directory, filename)

            #如果数据文件不存在,则从指定的网址下载数据文件

            if not os.path.exists(filepath):

                #_progress函数

                def _progress(count, block_size, total_size):

                    sys.stdout.write('\r>> Downloading %s %.1f%%'

                        % (filename, float(count *block_size)/float(total_size) * 100.0))

                    #刷新缓冲池,输出指定字符串

                    sys.stdout.flush()

                    #从网络地址下载数据文件

                    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)

                    #文件状态

                    statinfo = os.stat(filepath)

                       print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')

                    #数据文件解压

                    tarfile.open(filepath, 'r:gz').extractall(dest_directory)

读取数据函数:训练数据总共包含五个数据批次。验证数据只有一个批次。函数的参数是输入数据的目录地址以及是否生成随机标签,函数的返回值是图像和对应的标签数组,包含参数path和is_random_label,返回训练数据和标签。

def _read_one_batch(path, is_random_label):

    #打开指定数据文件

    fo = open(path, 'rb')

    #读取数据并存储至dicts中

    dicts = cPickle.load(fo)

    #关闭文件

    fo.close()

    #获取dicts中的data数据

    data = dicts['data']

    #获取dicts中的label数据

    if is_random_label is False:

        #标签label为标签词典数组

        label = np.array(dicts['labels'])

        else:

        #标签labels从low为0到high为10、大小为10000中随机产生整型数

        labels = np.random.randint(low=0, high=10, size=10000)

        #标签label为labels数组

        label = np.array(labels)

    #返回数据,标签

    return data, label

读入训练数据和验证数据函数:这个函数读入所有训练数据或者验证数据,如果需要随机排序,使用随机函数生成排序并且返回图像。返回训练数据和训练标签。

def read_in_all_images(address_list, shuffle=True, is_random_label = False):

    data = np.array([]).reshape([0, IMG_WIDTH * IMG_HEIGHT * IMG_DEPTH])

    label = np.array([])

    #对address_list进行address 次遍历

    for address in address_list:

        #输出Reading images from 与地址

        print 'Reading images from ' + address

        #数据批次batch_data,标签批次batch_label读地址、随机标签

        batch_data, batch_label = _read_one_batch(address, is_random_label)

        #数据data为data与batch_data的连接

        data = np.concatenate((data, batch_data))

        #标签label为label与batch_label的连接

        label = np.concatenate((label, batch_label))

        #数据数目num_data为标签长度

        num_data = len(label)

        #将数据data重塑为(数据数目,图像高度*图像宽度,图像深度)

        data = data.reshape((num_data, 

                                IMG_HEIGHT * IMG_WIDTH, 

                                IMG_DEPTH), 

                            order='F')

        #将数据data重塑为(数据数目,图像高度,图像宽度,图像深度)

        data = data.reshape((num_data,

                                IMG_HEIGHT,

                                IMG_WIDTH,

                                IMG_DEPTH))

        #重新排序

        if shuffle is True:

            #输出'Shuffling'

            print 'Shuffling'

            #顺序为数据数目置换检验

            order = np.random.permutation(num_data)

            #数据为从order开始的data列表

            data = data[order, ...]

            #标签为从order开始的label列表

            label = label[order]

        #数据data为float32

        data = data.astype(np.float32)

        #返回数据,标签

        return data, label

水平翻转函数:以50%的概率翻转一张图像,包含参数image,3D张量;axis,0代表垂直翻转,1代表水平翻转;返回翻转之后的3D图像。

def horizontal_flip(image, axis):

    #翻转支撑flip_prop为从low为0到high为2的随机整数

    flip_prop = np.random.randint(low=0, high=2)

    #如果翻转支撑flip_prop为0

    if flip_prop == 0:

        #图像image沿axis翻转

        image = cv2.flip(image, axis)

        #返回图像

    return image

白化图像函数:将图像白化,参数为image_np,返回白化后的图像。

def whitening_image(image_np):

    #对range进行i次遍历

    for i in range(len(image_np)):

        #平均值mean为image_np列表的平均值

        mean = np.mean(image_np[i, ...])

        #标准std为image_np中最大值,图像高度*图像宽度*图像深度平方根的倒数

        std = np.max([np.std(image_np[i, ...]),

                    1.0/np.sqrt(IMG_HEIGHT * IMG_WIDTH * IMG_DEPTH)])

        #image_np为(image_np-平均值)/std

        image_np[i,...] = (image_np[i, ...] - mean) / std

        #返回image_np

return image_np

随机裁剪和翻转函数:随机裁剪和随机翻转图像批次。包含参数padding_size,整型;batch_data,4D张量;返回随机裁剪和翻转后的图像。

def random_crop_and_flip(batch_data, padding_size):

    cropped_batch = np.zeros(len(batch_data) * IMG_HEIGHT *IMG_WIDTH *IMG_DEPTH)

                                        .reshape(len(batch_data),

                                                            IMG_HEIGHT, 

                                                            IMG_WIDTH,

                                                            IMG_DEPTH)

    #对range进行i次遍历

    for i in range(len(batch_data)):

        #x偏置为从low为0到high为2*补丁大小、大小为1的随机整型数

        x_offset = np.random.randint(low=0, high=2 * padding_size, size=1)[0]

        #y偏置为从low为0到high为2*补丁大小、大小为1的随机整型数

        y_offset = np.random.randint(low=0, high=2 * padding_size, size=1)[0]

        #裁剪批次cropped_batch为批次数据从i开始

        #x偏置为x偏置加图像高度,y偏置为y偏置加图像宽度

         ? cropped_batch[i, ...] = batch_data[i, ...][x_offset:x_offset+IMG_HEIGHT,

                            ?     y_offset:y_offset+IMG_WIDTH, :]

        #裁剪批次为水平翻转,图像为裁剪批次,翻转轴为1

         cropped_batch[i, ...] = horizontal_flip(image=cropped_batch[i, ...], axis=1)

    #返回裁剪批次

    return cropped_batch

准备训练数据函数:读取所有训练数据到numpy数组,在图像添加值为0的边框,包含参数padding_size,格式为整型。返回所有训练数据及相关标签。

def prepare_train_data(padding_size):

    #定义路径列表path_list

    path_list = []

    #对range进行i次遍历

    for i in range(1, NUM_TRAIN_BATCH+1):

        #更新路径列表为数据全目录加字符串i

        path_list.append(full_data_dir + str(i))

        #读取图像数据data,标签label

        data, label = read_in_all_images(path_list,

                                        is_random_label=TRAIN_RANDOM_LABEL)

        #设置边框大小

        pad_width = ((0, 0), (padding_size, padding_size), (padding_size, padding_size), (0, 0))

        #图像数据添加边框

        ? data = np.pad(data, pad_width=pad_width, mode='constant', constant_values=0)

    #返回数据、标签

    return data, label

#读取验证集数据函数,取验证数据,同时白化

def read_validation_data():

    #验证数组validation_array,验证标签validation_labels

    validation_array, validation_labels = read_in_all_images([vali_dir],

                                          is_random_label=VALI_RANDOM_LABEL)

    #验证数组validation_array为白化图像validation_array

    validation_array = whitening_image(validation_array)

    #返回验证数组、验证标签

    return validation_array, validation_labels


【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。