快速搭建一个手写数字识别的神经网络

举报
开飞机的大象 发表于 2018/11/24 10:06:33 2018/11/24
【摘要】 学习机器学习的理论知识难免会觉得枯燥乏味,不妨可以先快速实现一个简单的神经网络。让一部分网络先跑起来,最后掌握理论知识。

学习机器学习的理论知识难免会觉得枯燥乏味,不妨可以先快速实现一个简单的神经网络。让一部分网络先跑起来,最后掌握理论知识。

这里我们选择的是手写数字mnist数据集,首先,我们导入数据集

from keras.datasets import mnist
import matplotlib.pyplot as plt

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
print(train_images.shape)


这里可以看到mnist数据集的规模

mnist已经将数据集分为训练集与测试集,训练集有60000张图片,测试集有10000张图片,每张图片由28*28的矩阵组成,我们先对一张图像的矩阵进行可视化。

digit = test_images[0]
plt.imshow(digit, cmap=plt.cm.binary)
plt.show()

image-11.png

test_images[0]数据可视化

为了加深理解,我们把0-9都选择一个样本进行可视化

num = 0
for i in range(len(train_images)):
    if train_labels[i] == num and num < 10:
        num += 1
        plt.subplot(3,4,num)
        plt.axis('off')
        plt.tight_layout()
        plt.imshow(train_images[i], cmap='gray', interpolation='none')
        plt.title("Class {}".format(train_labels[i]))

image-12.png

这里展示了mnist数据集的0-9

mnist数据集导入成功了,接下来,我们使用Keras快速搭建神经网络。

from keras import models
from keras import layers

network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28*28,)))
network.add(layers.Dense(10, activation='softmax'))
network.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

Keras非常适合快速搭建网络,models.Sequential() 表示我们要把每一个数据处理层关联起来,layers.Dense(…)就是构造一个数据处理层。input_shape(28*28,)表示当前处理层接收的数据格式必须是长和宽都是28的二维数组。

image.png

上面的代码搭建了一个这样的网络,可以看出来这个网络有两层

到这里为止,我们的网络已经搭建好了。我们需要将数据输入到这个网络里,下面我们对数据集进行预处理。将长宽都为28的二维矩阵变为28*28的一维矩阵,由于每个像素的取值范围是0-255不便于计算,我们对数据进行了简单归一化。

train_images = train_images.reshape((60000, 28*28))
train_images = train_images.astype('float32') / 255

test_images = test_images.reshape((10000, 28*28))
test_images = test_images.astype('float32') / 25

图片的label我们不是很方便直接使用,我们将label转换为独热编码格式。

from keras.utils import to_categorical
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
print("before change:" ,test_labels[0])
print("after change: ", test_labels[0])

image-13.png

这里很容易可以看到数据转换前后的对比

数据预处理完毕,我们将数据输入网络进行训练。

network.fit(train_images, train_labels, epochs=10, batch_size = 128)

上面的代码中,train_images是用于训练的手写数字图片,train_labels对应的是图片的标记,batch_size 的意思是,每次网络从输入的图片数组中随机选取128个作为一组进行计算,每次计算的循环是10次。

image-14.png

训练过程

可以看到进过训练后,acc越来越接近1,最后的network得到了一个class对象。我们使用测试集队训练的结果进行测试,验证模型的准确性。

test_loss, test_acc = network.evaluate(test_images, test_labels, verbose=1)
print(test_loss) 
print('test_acc', test_acc)

训练的network对测试集进行预测时,正确率达到了98.27%

最后,我们输入一张图片,直观验证一下模型的识别效果。

from keras.datasets import mnist
i= 4
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
digit = test_images[i]
plt.imshow(digit, cmap=plt.cm.binary)
plt.show()
test_images = test_images.reshape((10000, 28*28))
res = network.predict(test_images)
print(res[i])
for k in range(res[i].shape[0]):
    if (res[i][k] == 1):
        print("the number for the picture is : ", k)
        break

识别效果

我们将识别的第5张图片显示出来,通过肉眼判断它应该是数字4,神经网络识别后给出的结果也是数字4,可见网络经过训练后,具备了手写数字图像识别的能力。

====================================================================

本文发表在李思原博客“机器在学习”

原文链接:http://www.siyuanblog.com/?p=1087

欢迎扫码关注我的微信公众号:聚数为塔

qrcode_for_gh_b8391fc7ce11_430.jpg

=====================================================================


【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。