CNN模型识别cifar数据集

举报
Echo_Wish 发表于 2022/07/12 16:54:50 2022/07/12
【摘要】 构建简单的CNN模型识别cifar数据集。经过几天的简单学习,尝试写了一个简单的CNN模型通过cifar数据集进行训练。效果一般,测试集上的的表现并不好,说明模型的构建不怎么样。# -*- coding = utf-8 -*-# @Time : 2020/10/16 16:19# @Author : tcc# @File : cifar_test.py# @Software : pycha...

构建简单的CNN模型识别cifar数据集。

经过几天的简单学习,尝试写了一个简单的CNN模型通过cifar数据集进行训练。效果一般,测试集上的的表现并不好,说明模型的构建不怎么样。

# -*- coding = utf-8 -*-
# @Time : 2020/10/16 16:19
# @Author : tcc
# @File : cifar_test.py
# @Software : pycharm

# 使用cnn模型训练识别cafir数据集


import keras
# 引入数据集
from keras.datasets import cifar10
# 反序列化和序列化
import pickle
# 主要用于获取文件的属性
import os
from keras.preprocessing.image import ImageDataGenerator
# 序列模型
from keras.models import Sequential
# 引入全连接层,dropout 层,flatten 层,展开层
from keras.layers import Dense, Dropout, Activation, Flatten
# 卷积层,池化层
from keras.layers import Conv2D, MaxPooling2D
# 引入numpy矩阵运算
import numpy as np
# 加载模型模块
from keras.models import load_model


# 文件读取,打开本地文件读取数据集数据
def open_file_data():
    pass


# 1.本地加载数据集
def load_dataset_data():
    # 加载训练集50000张32x32的rgb图片,测试集1000032x32的rgb图片
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    return (x_train, y_train), (x_test, y_test)


# 2.归一化(规范化)数据
def standard_data(x, y, x_, y_):
    x = x / 255
    x_ = x_ / 255
    # keras.utils.to_categorical将整型标签转为one_hot。y为int数组,num_classes为标签类别总数,大于max(y)(标签从0开始的)。
    y = keras.utils.to_categorical(y, 10)
    y_ = keras.utils.to_categorical(y_, 10)
    return x, y, x_, y_


# ..图片的可视化
def show_data(x):
    import matplotlib.pyplot as plt
    plt.imshow(x[0])
    plt.show()
    plt.imshow(x[1])
    plt.show()


# 将 RGB 图像转为灰度图
def rgb2gray(img):
    # Y' = 0.299 R + 0.587 G + 0.114 B
    # https://en.wikipedia.org/wiki/Grayscale#Converting_color_to_grayscale
    return np.dot(img[..., :3], [0.299, 0.587, 0.114])


# 3.构建CNN模型
def make_model():
    # 声明序贯模型
    model = Sequential()
    # 卷积层,32个3x3的卷积核,输入为32x32大小,通道数3的图像,边框填充,激活函数relu
    model.add(Conv2D(filters=32, kernel_size=(3, 3), activation='relu', input_shape=(32, 32, 3), padding='same'))
    # 卷积层,32个3x3的卷积核
    model.add(Conv2D(filters=32, kernel_size=(3, 3), activation='relu'))
    # 池化层,shape值除以2
    model.add(MaxPooling2D(pool_size=(2, 2)))
    # dropout层,舍弃0.25的神经元
    model.add(Dropout(0.25))
    # 卷积层
    model.add(Conv2D(filters=64, kernel_size=(3, 3), activation='relu', padding='same'))
    # 卷积层
    model.add(Conv2D(filters=64, kernel_size=(3, 3), activation='relu'))
    # 池化层
    model.add(MaxPooling2D(pool_size=(2, 2)))
    # dropout层
    model.add(Dropout(0.25))
    # flatten展开层,将二维三维张量摊平(展开)成一维向量
    model.add(Flatten())
    # 全连接层
    model.add(Dense(512, activation='relu'))
    # dropout层
    model.add(Dropout(0.5))
    # 全连接层
    model.add(Dense(10, activation='softmax'))
    model.summary()
    opt = keras.optimizers.rmsprop(lr=0.0001, decay=1e-6)
    model.compile(loss='categorical_crossentropy',
                  optimizer=opt,
                  metrics=['accuracy'])
    return model


# 训练模型所需的数据进行图像转换
def train_model():
    # 数据增强器,包含了范围20°内的随机旋转,±15%的缩放以及随机的水平翻转
    data_modify = ImageDataGenerator(
        rotation_range=20,
        zoom_range=0.15,
        horizontal_flip=True,
    )
    return data_modify


# main方法
def main():
    (x, y), (x_, y_) = load_dataset_data()
    x, y, x_, y_ = standard_data(x, y, x_, y_)
    # 这里是训练模型
    show_data(x)
    print(x.shape, y.shape)
    model = make_model()
    data = train_model()
    model.fit_generator(data.flow(x, y, batch_size=64), steps_per_epoch=1000, epochs=20, validation_data=(x_, y_))
    model.save('cifar10_trained_model.h5')
    # 下面是加载模型并进行测试
    mnist_model = load_model('cifar10_trained_model.h5')
    scores = mnist_model.evaluate(x_, y_, verbose=1)
    print('Test loss:', scores[0])
    print('Test accuracy:', scores[1])


if __name__ == "__main__":
    main()

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。