Image Classification with TensorFlow 2.0+

AI浩, published 2021/12/23

Contents

Abstract
Network Details
Training
  1. Import dependencies
  2. Set global parameters
  3. Load the data
  4. Define the model
  5. Split into training and validation sets
  6. Data augmentation
  7. Set up callbacks
  8. Train and save the model
  9. Save the training history
  Complete code
Testing
  1. Import dependencies
  2. Set global parameters
  3. Load the model
  4. Preprocess the image
  5. Predict the class
  Complete code


Abstract

This article implements image classification with a custom CNN in TensorFlow 2.0+, using 10,000 images from the Dogs vs. Cats dataset (5,000 cats and 5,000 dogs). Along the way it covers several techniques commonly used in image classification:

1. Image augmentation

2. Splitting the data into training and validation sets

3. Saving the best model with ModelCheckpoint

4. Adjusting the learning rate with ReduceLROnPlateau

5. Plotting the loss and accuracy curves and saving them as JPG images

Network Details
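
The network, defined in step 4 below, is a plain convolutional stack: pairs of 3×3 convolutions with 32, 64, 128, and 256 filters, each convolution followed by batch normalization and a PReLU activation, with max pooling between the pairs. A global average pooling layer and a dropout layer then feed a dense softmax classifier over the two classes.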

Training

1. Import dependencies

import os
import numpy as np
import cv2
from sklearn.model_selection import train_test_split
from tensorflow.keras import Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization
from tensorflow.keras.layers import Conv2D, MaxPooling2D, GlobalAveragePooling2D
from tensorflow.keras.layers import PReLU, Activation
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import img_to_array

2. Set global parameters

norm_size = 100                  # image size fed to the network, in pixels
datapath = 'train'               # root directory of the images
EPOCHS = 100                     # number of training epochs
INIT_LR = 1e-3                   # initial learning rate
labelList = []                   # labels
dicClass = {'cat': 0, 'dog': 1}  # class-name-to-index mapping
labelnum = 2                     # number of classes
batch_size = 4
3. Load the data

def loadImageData():
    imageList = []
    listImage = os.listdir(datapath)             # list all image files
    for img in listImage:                        # iterate over the images
        labelName = dicClass[img.split('.')[0]]  # filename prefix -> label index
        print(labelName)
        labelList.append(labelName)
        dataImgPath = os.path.join(datapath, img)
        print(dataImgPath)
        # np.fromfile + cv2.imdecode also handles non-ASCII paths on Windows
        image = cv2.imdecode(np.fromfile(dataImgPath, dtype=np.uint8), -1)
        # resize the image to the network input size
        image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
        image = img_to_array(image)
        imageList.append(image)
    imageList = np.array(imageList, dtype="float") / 255.0  # normalize to [0, 1]
    return imageList

print("Loading data...")
imageArr = loadImageData()
labelList = np.array(labelList)
print("Data loaded.")
print(labelList)
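
Note that the label is parsed from the filename prefix (img.split('.')[0]), so the train directory is assumed to follow the Kaggle Dogs vs. Cats naming convention, e.g. cat.0.jpg, cat.1.jpg, ..., dog.0.jpg, dog.1.jpg.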

4. Define the model

def bn_prelu(x):
    x = BatchNormalization(epsilon=1e-5)(x)
    x = PReLU()(x)
    return x

def build_model(out_dims, input_shape=(norm_size, norm_size, 3)):
    inputs_dim = Input(input_shape)
    x = Conv2D(32, (3, 3), strides=(2, 2), padding='same')(inputs_dim)
    x = bn_prelu(x)
    x = Conv2D(32, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = GlobalAveragePooling2D()(x)
    dp_1 = Dropout(0.5)(x)
    fc2 = Dense(out_dims)(dp_1)
    fc2 = Activation('softmax')(fc2)  # softmax over the classes
    model = Model(inputs=inputs_dim, outputs=fc2)
    return model

model = build_model(labelnum)            # build the model
optimizer = Adam(learning_rate=INIT_LR)  # optimizer with the initial learning rate
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

5. Split into training and validation sets

trainX, valX, trainY, valY = train_test_split(imageArr, labelList, test_size=0.3, random_state=42)
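
train_test_split reserves 30% of the images for validation. Since the two classes here are balanced this is usually fine as-is; passing stratify=labelList would additionally guarantee the same cat/dog ratio in both splits.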

6. Data augmentation

from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(featurewise_center=True,
                                   featurewise_std_normalization=True,
                                   rotation_range=20,
                                   width_shift_range=0.2,
                                   height_shift_range=0.2,
                                   horizontal_flip=True)
val_datagen = ImageDataGenerator()  # no augmentation for the validation set

# featurewise_center / featurewise_std_normalization need dataset-wide statistics,
# so the generator must be fit on the training data before flow() is called
train_datagen.fit(trainX)

train_generator = train_datagen.flow(trainX, trainY, batch_size=batch_size, shuffle=True)
val_generator = val_datagen.flow(valX, valY, batch_size=batch_size, shuffle=True)
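
A subtlety worth flagging: the featurewise options compute the mean and standard deviation over the whole training set, which is why train_datagen.fit(trainX) is added above; without it Keras emits a warning and skips the normalization. Note also that the validation generator applies no such normalization, so the two branches see slightly differently scaled inputs; dropping the two featurewise options entirely (the images are already scaled to [0, 1]) is a reasonable alternative.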

7. Set up callbacks

checkpointer = ModelCheckpoint(filepath='weights_best_simple_model.hdf5',
                               monitor='val_accuracy', verbose=1,
                               save_best_only=True, mode='max')
reduce = ReduceLROnPlateau(monitor='val_accuracy', patience=10,
                           verbose=1,
                           factor=0.5,
                           min_lr=1e-6)
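
ModelCheckpoint writes the model to weights_best_simple_model.hdf5 whenever val_accuracy reaches a new maximum, so the best epoch is preserved even if training later overfits. ReduceLROnPlateau halves the learning rate (factor=0.5) whenever val_accuracy has not improved for 10 consecutive epochs, down to a floor of 1e-6.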

8. Train and save the model

# In TF 2.x, model.fit accepts generators directly; fit_generator is deprecated
history = model.fit(train_generator,
                    steps_per_epoch=trainX.shape[0] // batch_size,
                    validation_data=val_generator,
                    epochs=EPOCHS,
                    validation_steps=valX.shape[0] // batch_size,
                    callbacks=[checkpointer, reduce],
                    verbose=1, shuffle=True)
model.save('my_model_.h5')

9. Save the training history

import matplotlib.pyplot as plt

loss_trend_graph_path = r"WW_loss.jpg"
acc_trend_graph_path = r"WW_acc.jpg"
print("Now, we start drawing the loss and acc trends graph...")
# summarize history for accuracy
fig = plt.figure(1)
plt.plot(history.history["accuracy"])
plt.plot(history.history["val_accuracy"])
plt.title("Model accuracy")
plt.ylabel("accuracy")
plt.xlabel("epoch")
plt.legend(["train", "val"], loc="upper left")
plt.savefig(acc_trend_graph_path)
plt.close(1)
# summarize history for loss
fig = plt.figure(2)
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("Model loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.legend(["train", "val"], loc="upper left")
plt.savefig(loss_trend_graph_path)
plt.close(2)
print("We are done, everything seems OK...")
# on Windows, shut the machine down 10 seconds after training finishes (optional)
os.system("shutdown -s -t 10")

Complete code:

import os
import numpy as np
import cv2
from sklearn.model_selection import train_test_split
from tensorflow.keras import Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization
from tensorflow.keras.layers import Conv2D, MaxPooling2D, GlobalAveragePooling2D
from tensorflow.keras.layers import PReLU, Activation
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import img_to_array, ImageDataGenerator
import matplotlib.pyplot as plt

norm_size = 100                  # image size fed to the network, in pixels
datapath = 'train'               # root directory of the images
EPOCHS = 100                     # number of training epochs
INIT_LR = 1e-3                   # initial learning rate
labelList = []                   # labels
dicClass = {'cat': 0, 'dog': 1}  # class-name-to-index mapping
labelnum = 2                     # number of classes
batch_size = 4

def loadImageData():
    imageList = []
    listImage = os.listdir(datapath)             # list all image files
    for img in listImage:
        labelName = dicClass[img.split('.')[0]]  # filename prefix -> label index
        print(labelName)
        labelList.append(labelName)
        dataImgPath = os.path.join(datapath, img)
        print(dataImgPath)
        # np.fromfile + cv2.imdecode also handles non-ASCII paths on Windows
        image = cv2.imdecode(np.fromfile(dataImgPath, dtype=np.uint8), -1)
        # resize the image to the network input size
        image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
        image = img_to_array(image)
        imageList.append(image)
    imageList = np.array(imageList, dtype="float") / 255.0  # normalize to [0, 1]
    return imageList

print("Loading data...")
imageArr = loadImageData()
labelList = np.array(labelList)
print("Data loaded.")
print(labelList)

def bn_prelu(x):
    x = BatchNormalization(epsilon=1e-5)(x)
    x = PReLU()(x)
    return x

def build_model(out_dims, input_shape=(norm_size, norm_size, 3)):
    inputs_dim = Input(input_shape)
    x = Conv2D(32, (3, 3), strides=(2, 2), padding='same')(inputs_dim)
    x = bn_prelu(x)
    x = Conv2D(32, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = GlobalAveragePooling2D()(x)
    dp_1 = Dropout(0.5)(x)
    fc2 = Dense(out_dims)(dp_1)
    fc2 = Activation('softmax')(fc2)  # softmax over the classes
    model = Model(inputs=inputs_dim, outputs=fc2)
    return model

model = build_model(labelnum)
optimizer = Adam(learning_rate=INIT_LR)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

trainX, valX, trainY, valY = train_test_split(imageArr, labelList, test_size=0.3, random_state=42)

train_datagen = ImageDataGenerator(featurewise_center=True,
                                   featurewise_std_normalization=True,
                                   rotation_range=20,
                                   width_shift_range=0.2,
                                   height_shift_range=0.2,
                                   horizontal_flip=True)
val_datagen = ImageDataGenerator()  # no augmentation for the validation set
# featurewise statistics must be computed before flow() is used
train_datagen.fit(trainX)
train_generator = train_datagen.flow(trainX, trainY, batch_size=batch_size, shuffle=True)
val_generator = val_datagen.flow(valX, valY, batch_size=batch_size, shuffle=True)

checkpointer = ModelCheckpoint(filepath='weights_best_simple_model.hdf5',
                               monitor='val_accuracy', verbose=1,
                               save_best_only=True, mode='max')
reduce = ReduceLROnPlateau(monitor='val_accuracy', patience=10,
                           verbose=1,
                           factor=0.5,
                           min_lr=1e-6)

# In TF 2.x, model.fit accepts generators directly; fit_generator is deprecated
history = model.fit(train_generator,
                    steps_per_epoch=trainX.shape[0] // batch_size,
                    validation_data=val_generator,
                    epochs=EPOCHS,
                    validation_steps=valX.shape[0] // batch_size,
                    callbacks=[checkpointer, reduce],
                    verbose=1, shuffle=True)
model.save('my_model_.h5')
print(history)

loss_trend_graph_path = r"WW_loss.jpg"
acc_trend_graph_path = r"WW_acc.jpg"
print("Now, we start drawing the loss and acc trends graph...")
# summarize history for accuracy
fig = plt.figure(1)
plt.plot(history.history["accuracy"])
plt.plot(history.history["val_accuracy"])
plt.title("Model accuracy")
plt.ylabel("accuracy")
plt.xlabel("epoch")
plt.legend(["train", "val"], loc="upper left")
plt.savefig(acc_trend_graph_path)
plt.close(1)
# summarize history for loss
fig = plt.figure(2)
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("Model loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.legend(["train", "val"], loc="upper left")
plt.savefig(loss_trend_graph_path)
plt.close(2)
print("We are done, everything seems OK...")
# on Windows, shut the machine down 10 seconds after training finishes (optional)
os.system("shutdown -s -t 10")

Testing

1. Import dependencies

import time
import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.models import load_model

2. Set global parameters

norm_size = 100
imagelist = []
emotion_labels = {
    0: 'cat',
    1: 'dog'
}

3. Load the model

emotion_classifier = load_model("my_model_.h5")
t1 = time.time()
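
Note: my_model_.h5 holds the weights from the final epoch. If you want the epoch with the best validation accuracy instead, load the checkpoint file weights_best_simple_model.hdf5 that ModelCheckpoint saved during training.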

4. Preprocess the image

image = cv2.imdecode(np.fromfile('test/8.jpg', dtype=np.uint8), -1)
# resize the image to the network input size and scale it to [0, 1]
image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
image = img_to_array(image)
imagelist.append(image)
imageList = np.array(imagelist, dtype="float") / 255.0

5. Predict the class

pre = np.argmax(emotion_classifier.predict(imageList))
emotion = emotion_labels[pre]
t2 = time.time()
print(emotion)
t3 = t2 - t1
print(t3)
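
To classify several images at once, the same preprocessing can be applied per file and the results stacked into one batch so that predict runs a single forward pass. Below is a minimal sketch; the predict_folder helper and the test/ directory layout are illustrative assumptions, not part of the original code:

import os

def predict_folder(folder='test'):
    batch, names = [], []
    for name in sorted(os.listdir(folder)):
        # same preprocessing as for a single image
        img = cv2.imdecode(np.fromfile(os.path.join(folder, name), dtype=np.uint8), -1)
        img = cv2.resize(img, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
        batch.append(img_to_array(img))
        names.append(name)
    batch = np.array(batch, dtype="float") / 255.0
    preds = np.argmax(emotion_classifier.predict(batch), axis=1)  # one pass for the whole batch
    for name, p in zip(names, preds):
        print(name, emotion_labels[p])

predict_folder('test')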

Complete code

import time
import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.models import load_model

norm_size = 100
imagelist = []
emotion_labels = {
    0: 'cat',
    1: 'dog'
}
emotion_classifier = load_model("my_model_.h5")
t1 = time.time()
image = cv2.imdecode(np.fromfile('test/8.jpg', dtype=np.uint8), -1)
# resize the image to the network input size and scale it to [0, 1]
image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
image = img_to_array(image)
imagelist.append(image)
imageList = np.array(imagelist, dtype="float") / 255.0
pre = np.argmax(emotion_classifier.predict(imageList))
emotion = emotion_labels[pre]
t2 = time.time()
print(emotion)
t3 = t2 - t1
print(t3)

Source: wanghao.blog.csdn.net, author: AI浩. Copyright belongs to the original author; please contact the author for reprint permission.

Original link: wanghao.blog.csdn.net/article/details/106166653
