《深度学习:主流框架和编程实战》——2.3.4 详细代码解析(2)
2.3.4 详细代码解析(2)
3)cifar10_test.py文件定义了Test类,用来预测图像的类别。其中分别定义了测试函数、获取顶层标签函数、显示结果函数。
cifar10_test.py
#从resnet中导入所有函数
from resnet import *
#从datetime中导入datetime
from datetime import datetime
#导入time包
import time
#从cifar10_input中导入所有函数
from cifar10_input import *
#以pd的形式导入pandas
import pandas as pd
#从PIL中导入Image
from PIL import Image
#导入numpy
import numpy
#导入os
import os
#Test类
class Test(object):
#定义初始化函数
def __init__(self,pathname):
#测试图像路径为路径名称
self.test_image_path=pathname
#测试图像占位符,参数包括占位符类型、测试批次大小、图像信息
self.test_image_placeholder = tf.placeholder(dtype=tf.float32,
shape=[FLAGS.test_batch_size,
IMG_HEIGHT,IMG_WIDTH,IMG_DEPTH])
#测试图像数组
self.test_image_array=[]
#定义测试函数
def test(self):
#测试图像数组
def __init__(self,pathname):
#路径地址为测试路径地址
pathDir = os.listdir(self.test_image_path)
#遍历路径地址
for allDir in pathDir:
#将测试图像数据的地址加入child中
child = os.path.join('%s%s' % (self.test_image_path, allDir))
#测试图像为32×32大小的数组
testimage = numpy.asarray(Image.open(child).resize((32, 32),
Image.ANTIALIAS))
#将测试图像加入测试图像数组中
test_image_array.append(testimage)
#将测试图像转化为数组类型
self.test_image_array = numpy.array(test_image_array)
#统计测试图像数量
num_test_images = len(self.test_image_array)
#统计批次数量
num_batches = num_test_images // FLAGS.test_batch_size
#统计剩余图像
remain_images = num_test_images % FLAGS.test_batch_size
print '%i test batches in total...' %num_batches
#定义比数logits占位符为图像占位符、残差构件的数目、不可重新使用
logits = inference(self.test_image_placeholder,
FLAGS.num_residual_blocks,
reuse=False)
#定义预测predictions等于softmax(logits)
predictions = tf.nn.softmax(logits)
#将训练中所有变量保存到saver中
saver = tf.train.Saver(tf.all_variables())
#定义一个新的会话sess
sess = tf.Session()
#从FLAGS.test_ck pt_path恢复会话
saver.restore(sess, FLAGS.test_ckpt_path)
#输出从FLAGS.test_ckpt_path恢复模型
print 'Model restored from ', FLAGS.test_ckpt_path
#定义预测数组
prediction_array = np.array([]).reshape(-1, NUM_CLASS)
#对range进行step次遍历num_batches
for step in range(num_batches):
#如果步数step能被10整除
if step % 10 == 0:
#输出完成步数step批次
print '%i batches finished!' %step
#偏移为(步数*测试批次大小)
offset = step * FLAGS.test_batch_size
#测试图像更新为[offset:offset+FLAGS.test_batch_size, ...]
test_image_batch =
self.test_image_array[offset:offset+FLAGS.test_batch_size, ...]
#批次预测数组与预测、测试图像占位符、测试图像批次有关
batch_prediction_array = sess.run(predictions,
feed_dict={self.test_image_placeholder:
test_image_batch})
#预测数组连接
? prediction_array = np.concatenate((prediction_array, batch_prediction_array))
#如果残余图像不为0
if remain_images != 0:
self.test_image_placeholder = tf.placeholder(dtype=tf.float32,
shape=[remain_images,
IMG_HEIGHT,
IMG_WIDTH,
IMG_DEPTH])
logits = inference(self.test_image_placeholder,
FLAGS.num_residual_blocks,
reuse=True)
#定义预测predictions等于softmax(logits)
predictions = tf.nn.softmax(logits)
#定义测试图像批次为具有残余图像参数的测试图像数组
test_image_batch = self.test_image_array[-remain_images:, ...]
#批次预测数组与预测、测试图像占位符、测试图像批次有关
batch_prediction_array = sess.run(predictions,
feed_dict={self.test_image_placeholder:
test_image_batch})
#预测数组连接
? prediction_array = np.concatenate((prediction_array, batch_prediction_array))
self.prediction_array=prediction_array
#返回预测数组
return prediction_array
#定义获取顶层标签函数,参数为predict_array
def get_top_1_label(self,predict_array):
#定义数组维度为预测数组长度
array_dim = len(predict_array)
#预测标签初始化,为零
predict_label = 0
#最大预测值为预测数组第一个值
maxpredict = predict_array[0]
#遍历数组维度
for i in range(array_dim):
#如果最大预测值小于当前值
if maxpredict < predict_array[i]:
#更新最大预测值为当前值
maxpredict = predict_array[i]
#预测标签为i
predict_label = i
#返回预测标签
return predict_label
#定义显示top_k结果函数,参数为k
def disp_k_result(self,k):
#类别标签
label_name = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
#遍历k
for i in range(k):
#定义临时图片数组,存放测试图像
tempimage = Image.fromarray(self.test_image_array[i, ...])
#图像名为测试图像+i
imagename='test%i.jpg'%i
#存储图像名
tempimage.save(imagename)
#定义临时预测标签
temp_predict_label = self.get_top_1_label(self.prediction_array[i, ...])
#输出图像名加预测值加标签名
print imagename+' predict: '+label_name[temp_predict_label]
4)cifar10_input.py文件自动地下载并提取cifar10数据,读入训练数据、验证数据,并对数据做水平翻转、裁剪、白化等操作。
cifar10_input.py
#导入tarfile包,主要用于压缩、解压tar文件
import tarfile
#从six.moves中导入urllib包,主要用于操作url
from six.moves import urllib
#导入sys包,主要包含Python解释器和与它的环境有关的函数
import sys
#以np的形式导入numpy包,主要用于利用数组表示向量、矩阵数据结构
import numpy as np
#导入cPickle包,主要用于将内存中的对象转换成为文本流
import cPickle
#导入os包,包括各种各样的函数,以实现操作系统的许多功能
import os
#导入cv2包,包含OpenCV主要函数
import cv2
#数据路径
data_dir = 'cifar10_data'
#数据完全路径
full_data_dir = 'cifar10_data/cifar-10-batches-py/data_batch_'
#验证数据路径
vali_dir = 'cifar10_data/cifar-10-batches-py/test_batch'
#数据的统一资源定位符
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
#图像宽度
IMG_WIDTH = 32
#图像高度
IMG_HEIGHT = 32
#图像深度
IMG_DEPTH = 3
#类的数目
NUM_CLASS = 10
#训练随机标签TRAIN_RANDOM_LABEL为假 #想要从训练数据中使用随机标签
TRAIN_RANDOM_LABEL = False # Want to use random label for train data?
#验证随机标签VALI_RANDOM_LABEL为假 #想要从验证集使用随机标签
VALI_RANDOM_LABEL = False # Want to use random label for validation?
#训练批次数NUM_TRAIN_BATCH为5 #想要读入多少文件批次,从0到5
NUM_TRAIN_BATCH = 5# How many batches of files you want to read in, from 0 to 5)
#全批次大小
EPOCH_SIZE = 10000 * NUM_TRAIN_BATCH
#下载提取函数,会自动地下载并提取cifar10数据
def maybe_download_and_extract():
#目标目录dest_directory为data_dir
dest_directory = data_dir
#如果指定目录不存在,则创建目标目录
if not os.path.exists(dest_directory):
#生成目录(dest_directory)
os.makedirs(dest_directory)
#文件名为数据的统一资源定位符以'/'分割
filename = DATA_URL.split('/')[-1]
#文件路径filepath将文件名filename加入目标目录
filepath = os.path.join(dest_directory, filename)
#如果数据文件不存在,则从指定的网址下载数据文件
if not os.path.exists(filepath):
#_progress函数
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%'
% (filename, float(count *block_size)/float(total_size) * 100.0))
#刷新缓冲池,输出指定字符串
sys.stdout.flush()
#从网络地址下载数据文件
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
#文件状态
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
#数据文件解压
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
读取数据函数:训练数据总共包含五个数据批次。验证数据只有一个批次。函数的参数是输入数据的目录地址以及是否生成随机标签,函数的返回值是图像和对应的标签数组,包含参数path和is_random_label,返回训练数据和标签。
def _read_one_batch(path, is_random_label):
#打开指定数据文件
fo = open(path, 'rb')
#读取数据并存储至dicts中
dicts = cPickle.load(fo)
#关闭文件
fo.close()
#获取dicts中的data数据
data = dicts['data']
#获取dicts中的label数据
if is_random_label is False:
#标签label为标签词典数组
label = np.array(dicts['labels'])
else:
#标签labels从low为0到high为10、大小为10000中随机产生整型数
labels = np.random.randint(low=0, high=10, size=10000)
#标签label为labels数组
label = np.array(labels)
#返回数据,标签
return data, label
读入训练数据和验证数据函数:这个函数读入所有训练数据或者验证数据,如果需要随机排序,使用随机函数生成排序并且返回图像。返回训练数据和训练标签。
def read_in_all_images(address_list, shuffle=True, is_random_label = False):
data = np.array([]).reshape([0, IMG_WIDTH * IMG_HEIGHT * IMG_DEPTH])
label = np.array([])
#对address_list进行address 次遍历
for address in address_list:
#输出Reading images from 与地址
print 'Reading images from ' + address
#数据批次batch_data,标签批次batch_label读地址、随机标签
batch_data, batch_label = _read_one_batch(address, is_random_label)
#数据data为data与batch_data的连接
data = np.concatenate((data, batch_data))
#标签label为label与batch_label的连接
label = np.concatenate((label, batch_label))
#数据数目num_data为标签长度
num_data = len(label)
#将数据data重塑为(数据数目,图像高度*图像宽度,图像深度)
data = data.reshape((num_data,
IMG_HEIGHT * IMG_WIDTH,
IMG_DEPTH),
order='F')
#将数据data重塑为(数据数目,图像高度,图像宽度,图像深度)
data = data.reshape((num_data,
IMG_HEIGHT,
IMG_WIDTH,
IMG_DEPTH))
#重新排序
if shuffle is True:
#输出'Shuffling'
print 'Shuffling'
#顺序为数据数目置换检验
order = np.random.permutation(num_data)
#数据为从order开始的data列表
data = data[order, ...]
#标签为从order开始的label列表
label = label[order]
#数据data为float32
data = data.astype(np.float32)
#返回数据,标签
return data, label
水平翻转函数:以50%的概率翻转一张图像,包含参数image,3D张量;axis,0代表垂直翻转,1代表水平翻转;返回翻转之后的3D图像。
def horizontal_flip(image, axis):
#翻转支撑flip_prop为从low为0到high为2的随机整数
flip_prop = np.random.randint(low=0, high=2)
#如果翻转支撑flip_prop为0
if flip_prop == 0:
#图像image沿axis翻转
image = cv2.flip(image, axis)
#返回图像
return image
白化图像函数:将图像白化,参数为image_np,返回白化后的图像。
def whitening_image(image_np):
#对range进行i次遍历
for i in range(len(image_np)):
#平均值mean为image_np列表的平均值
mean = np.mean(image_np[i, ...])
#标准std为image_np中最大值,图像高度*图像宽度*图像深度平方根的倒数
std = np.max([np.std(image_np[i, ...]),
1.0/np.sqrt(IMG_HEIGHT * IMG_WIDTH * IMG_DEPTH)])
#image_np为(image_np-平均值)/std
image_np[i,...] = (image_np[i, ...] - mean) / std
#返回image_np
return image_np
随机裁剪和翻转函数:随机裁剪和随机翻转图像批次。包含参数padding_size,整型;batch_data,4D张量;返回随机裁剪和翻转后的图像。
def random_crop_and_flip(batch_data, padding_size):
cropped_batch = np.zeros(len(batch_data) * IMG_HEIGHT *IMG_WIDTH *IMG_DEPTH)
.reshape(len(batch_data),
IMG_HEIGHT,
IMG_WIDTH,
IMG_DEPTH)
#对range进行i次遍历
for i in range(len(batch_data)):
#x偏置为从low为0到high为2*补丁大小、大小为1的随机整型数
x_offset = np.random.randint(low=0, high=2 * padding_size, size=1)[0]
#y偏置为从low为0到high为2*补丁大小、大小为1的随机整型数
y_offset = np.random.randint(low=0, high=2 * padding_size, size=1)[0]
#裁剪批次cropped_batch为批次数据从i开始
#x偏置为x偏置加图像高度,y偏置为y偏置加图像宽度
? cropped_batch[i, ...] = batch_data[i, ...][x_offset:x_offset+IMG_HEIGHT,
? y_offset:y_offset+IMG_WIDTH, :]
#裁剪批次为水平翻转,图像为裁剪批次,翻转轴为1
cropped_batch[i, ...] = horizontal_flip(image=cropped_batch[i, ...], axis=1)
#返回裁剪批次
return cropped_batch
准备训练数据函数:读取所有训练数据到numpy数组,在图像添加值为0的边框,包含参数padding_size,格式为整型。返回所有训练数据及相关标签。
def prepare_train_data(padding_size):
#定义路径列表path_list
path_list = []
#对range进行i次遍历
for i in range(1, NUM_TRAIN_BATCH+1):
#更新路径列表为数据全目录加字符串i
path_list.append(full_data_dir + str(i))
#读取图像数据data,标签label
data, label = read_in_all_images(path_list,
is_random_label=TRAIN_RANDOM_LABEL)
#设置边框大小
pad_width = ((0, 0), (padding_size, padding_size), (padding_size, padding_size), (0, 0))
#图像数据添加边框
? data = np.pad(data, pad_width=pad_width, mode='constant', constant_values=0)
#返回数据、标签
return data, label
#读取验证集数据函数,取验证数据,同时白化
def read_validation_data():
#验证数组validation_array,验证标签validation_labels
validation_array, validation_labels = read_in_all_images([vali_dir],
is_random_label=VALI_RANDOM_LABEL)
#验证数组validation_array为白化图像validation_array
validation_array = whitening_image(validation_array)
#返回验证数组、验证标签
return validation_array, validation_labels
- 点赞
- 收藏
- 关注作者
评论(0)