TensorFlow2.0以上版本的图像分类
目录
摘要
本篇文章采用CNN实现图像分类,图像选取了猫狗大战数据集中的1万张图像(猫狗各5千张)。模型采用自定义的CNN网络,基于TensorFlow 2.0及以上版本实现。通过本篇文章,你可以学到图像分类的常用手段,包括:
1、图像增强
2、训练集和验证集切分
3、使用ModelCheckpoint保存最优模型
4、使用ReduceLROnPlateau调整学习率。
5、将训练过程中的loss和accuracy曲线绘制并保存为jpg图片。
网络详解
训练部分
1、导入依赖
import os
import numpy as np
from tensorflow import keras
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout,BatchNormalization,Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D,GlobalAveragePooling2D
import cv2
from tensorflow.keras.preprocessing.image import img_to_array
from sklearn.model_selection import train_test_split
from tensorflow.python.keras import Input
from tensorflow.python.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.python.keras.layers import PReLU, Activation
from tensorflow.python.keras.models import Model
2、设置全局参数
  
   - 
    
     
    
    
     
      norm_size = 100  # Side length (pixels) of the square images fed to the network.
      datapath = 'train'  # Root directory containing the training images.
      EPOCHS = 100  # Number of training epochs.
      INIT_LR = 1e-3  # Initial learning rate.
      labelList = []  # Integer labels; loadImageData() appends one per image, in file order.
      dicClass = {'cat': 0, 'dog': 1}  # File-name prefix -> class index.
      labelnum = len(dicClass)  # Number of classes, derived so it cannot drift from dicClass.
      batch_size = 4
 3、加载数据
  
   - 
    
     
    
    
     
      def loadImageData():
          """Load every image under ``datapath`` as a normalized array.

          The class of each file is taken from its name prefix before the first
          '.' (e.g. ``cat.123.jpg`` -> ``dicClass['cat']``) and appended to the
          module-level ``labelList``, in the same order as the returned images.

          Returns:
              float32 array of shape (num_images, norm_size, norm_size, 3) with
              pixel values scaled to [0, 1].

          Raises:
              KeyError: if a file-name prefix is not a key of ``dicClass``.
              ValueError: if a file cannot be decoded as an image.
          """
          imageList = []
          listImage = os.listdir(datapath)  # every file in the data directory
          for img in listImage:
              labelName = dicClass[img.split('.')[0]]  # name prefix -> class index
              print(labelName)
              labelList.append(labelName)
              dataImgPath = os.path.join(datapath, img)
              print(dataImgPath)
              # np.fromfile + imdecode is presumably used instead of cv2.imread so that
              # non-ASCII paths work on Windows. IMREAD_COLOR always yields 3 channels,
              # so grayscale/RGBA files match the model's (norm_size, norm_size, 3)
              # input instead of producing a ragged array.
              image = cv2.imdecode(np.fromfile(dataImgPath, dtype=np.uint8), cv2.IMREAD_COLOR)
              if image is None:
                  raise ValueError('could not decode image: %s' % dataImgPath)
              image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
              image = img_to_array(image)
              imageList.append(image)
          # float32 is half the size of the float64 produced by an integer array / 255.0.
          imageList = np.array(imageList, dtype="float32") / 255.0
          return imageList
     
    
 
  
   - 
    
     
    
    
     
      # Read the whole dataset into memory. loadImageData fills labelList as a side
      # effect, so afterwards the list is converted into an array aligned with imageArr.
      print("开始加载数据")
      imageArr = loadImageData()
      labelList = np.asarray(labelList)
      print("加载数据完成")
      print(labelList)
     
    
 4、定义模型
  
   - 
    
     
    
    
     
      def bn_prelu(x):
          """Batch-normalize ``x`` and pass it through a PReLU activation.

          Args:
              x: Keras tensor produced by the preceding layer.

          Returns:
              The activated Keras tensor.
          """
          normalized = BatchNormalization(epsilon=1e-5)(x)
          return PReLU()(normalized)
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
      def build_model(out_dims, input_shape=(norm_size, norm_size, 3), dropout_rate=0.5):
          """Build the functional-API CNN classifier.

          Layout: four stages of two (3x3 Conv2D -> BatchNorm -> PReLU) units with
          32, 64, 128 and 256 filters. The very first convolution uses stride 2,
          stages 1-3 end with 2x2 max pooling, and the last stage feeds global
          average pooling, dropout and a softmax Dense head.

          Args:
              out_dims: number of output classes.
              input_shape: (height, width, channels) of the input images.
              dropout_rate: dropout fraction applied to the pooled features
                  (defaults to 0.5).

          Returns:
              An uncompiled Keras ``Model`` mapping images to class probabilities.

          NOTE(review): ``Input``, ``Activation``, ``PReLU`` and ``Model`` are imported
          from the private ``tensorflow.python.keras`` package while the other layers
          come from ``tensorflow.keras``; on TF releases where those are different
          packages the mix can fail - consider importing everything from
          ``tensorflow.keras``.
          """
          inputs_dim = Input(input_shape)
          x = inputs_dim
          stage_filters = (32, 64, 128, 256)
          for stage, filters in enumerate(stage_filters):
              # Only the very first convolution downsamples (stride 2).
              first_strides = (2, 2) if stage == 0 else (1, 1)
              x = bn_prelu(Conv2D(filters, (3, 3), strides=first_strides, padding='same')(x))
              x = bn_prelu(Conv2D(filters, (3, 3), strides=(1, 1), padding='same')(x))
              # Pool between stages; the last stage goes straight to global pooling.
              if stage < len(stage_filters) - 1:
                  x = MaxPooling2D(pool_size=(2, 2))(x)
          x = GlobalAveragePooling2D()(x)
          dp_1 = Dropout(dropout_rate)(x)
          fc2 = Dense(out_dims)(dp_1)
          # Softmax (not sigmoid): one probability per class, summing to 1, matching
          # the sparse_categorical_crossentropy loss the model is compiled with.
          fc2 = Activation('softmax')(fc2)
          model = Model(inputs=inputs_dim, outputs=fc2)
          return model
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
      model = build_model(labelnum)  # Build the network.
      # `learning_rate` is the supported keyword; `lr` is deprecated and is rejected
      # by newer Keras optimizers.
      optimizer = Adam(learning_rate=INIT_LR)
      # Integer class labels + softmax output -> sparse categorical cross-entropy.
      model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
     
    
 5、切割训练集和验证集
# Hold out 30% of the samples for validation; the fixed seed makes the split reproducible.
trainX, valX, trainY, valY = train_test_split(
    imageArr,
    labelList,
    test_size=0.3,
    random_state=42,
)
 6、数据增强
  
   - 
    
     
    
    
     
      from tensorflow.keras.preprocessing.image import ImageDataGenerator
      # Geometric augmentation for the training set only.
      # featurewise_center / featurewise_std_normalization are deliberately not used:
      # they only take effect after train_datagen.fit(trainX) (otherwise Keras just
      # warns), and the same statistics would then be needed for validation and
      # inference. Pixels are already scaled to [0, 1] by loadImageData.
      train_datagen = ImageDataGenerator(
          rotation_range=20,
          width_shift_range=0.2,
          height_shift_range=0.2,
          horizontal_flip=True,
      )
      val_datagen = ImageDataGenerator()  # No augmentation for validation data.
      train_generator = train_datagen.flow(trainX, trainY, batch_size=batch_size, shuffle=True)
      # Validation order does not change the metrics, so keep it deterministic.
      val_generator = val_datagen.flow(valX, valY, batch_size=batch_size, shuffle=False)
     
    
 7、设置callback函数
  
   - 
    
     
    
    
     
      # Save weights only when validation accuracy reaches a new best.
      checkpointer = ModelCheckpoint(
          filepath='weights_best_simple_model.hdf5',
          monitor='val_accuracy',
          mode='max',
          save_best_only=True,
          verbose=1,
      )
      # Halve the learning rate after 10 epochs without a validation-accuracy
      # improvement, never going below 1e-6.
      reduce = ReduceLROnPlateau(
          monitor='val_accuracy',
          mode='max',
          factor=0.5,
          patience=10,
          min_lr=1e-6,
          verbose=1,
      )
     
    
 8、训练并保存模型
  
   - 
    
     
    
    
     
      # `fit` accepts the generators directly (`fit_generator` is deprecated in TF 2.x).
      # len(generator) is the batch count rounded up, so the step counts are integers
      # and a final partial batch is still used.
      history = model.fit(train_generator,
                          steps_per_epoch=len(train_generator),
                          validation_data=val_generator,
                          epochs=EPOCHS,
                          validation_steps=len(val_generator),
                          callbacks=[checkpointer, reduce],
                          verbose=1, shuffle=True)
      # Saves the model as of the last epoch; the best epoch by val_accuracy is in
      # weights_best_simple_model.hdf5, written by the checkpoint callback.
      model.save('my_model_.h5')
     
    
 9、绘制并保存训练曲线(loss与accuracy)
  
   - 
    
     
    
    
     
      import os
      import matplotlib.pyplot as plt

      loss_trend_graph_path = r"WW_loss.jpg"
      acc_trend_graph_path = r"WW_acc.jpg"
      # Set to True only if the machine should power off after training (Windows only).
      SHUTDOWN_WHEN_DONE = False


      def _plot_history(metric, title, out_path):
          """Plot training and validation curves of ``metric`` from the global ``history`` and save to ``out_path``."""
          fig = plt.figure()
          plt.plot(history.history[metric])
          plt.plot(history.history["val_" + metric])
          plt.title(title)
          plt.ylabel(metric)
          plt.xlabel("epoch")
          # The second curve comes from the validation split, not a separate test set.
          plt.legend(["train", "validation"], loc="upper left")
          plt.savefig(out_path)
          plt.close(fig)


      print("Now,we start drawing the loss and acc trends graph...")
      _plot_history("accuracy", "Model accuracy", acc_trend_graph_path)
      _plot_history("loss", "Model loss", loss_trend_graph_path)
      print("We are done, everything seems OK...")

      # Powering off is opt-in and only attempted on Windows, where
      # `shutdown -s -t 10` means "shut down in 10 seconds".
      if SHUTDOWN_WHEN_DONE and os.name == "nt":
          os.system("shutdown -s -t 10")
     
    
 

完整代码:
  
   - 
    
     
    
    
     
      import os
     
    
- 
    
     
    
    
     
      import numpy as np
     
    
- 
    
     
    
    
     
      from tensorflow import keras
     
    
- 
    
     
    
    
     
      from tensorflow.keras.optimizers import Adam
     
    
- 
    
     
    
    
     
      from tensorflow.keras.models import Sequential
     
    
- 
    
     
    
    
     
      from tensorflow.keras.layers import Dense, Dropout,BatchNormalization,Flatten
     
    
- 
    
     
    
    
     
      from tensorflow.keras.layers import Conv2D, MaxPooling2D,GlobalAveragePooling2D
     
    
- 
    
     
    
    
     
      import cv2
     
    
- 
    
     
    
    
     
      from tensorflow.keras.preprocessing.image import img_to_array
     
    
- 
    
     
    
    
     
      from sklearn.model_selection import train_test_split
     
    
- 
    
     
    
    
     
      from tensorflow.python.keras import Input
     
    
- 
    
     
    
    
     
      from tensorflow.python.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
     
    
- 
    
     
    
    
     
      from tensorflow.python.keras.layers import PReLU, Activation
     
    
- 
    
     
    
    
     
      from tensorflow.python.keras.models import Model
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
      norm_size = 100  # Side length (pixels) of the square images fed to the network.
      datapath = 'train'  # Root directory containing the training images.
      EPOCHS = 100  # Number of training epochs.
      INIT_LR = 1e-3  # Initial learning rate.
      labelList = []  # Integer labels; loadImageData() appends one per image, in file order.
      dicClass = {'cat': 0, 'dog': 1}  # File-name prefix -> class index.
      labelnum = len(dicClass)  # Number of classes, derived so it cannot drift from dicClass.
      batch_size = 4
     
    
- 
    
     
    
    
     
      def loadImageData():
          """Load every image under ``datapath`` as a normalized array.

          The class of each file is taken from its name prefix before the first
          '.' (e.g. ``cat.123.jpg`` -> ``dicClass['cat']``) and appended to the
          module-level ``labelList``, in the same order as the returned images.

          Returns:
              float32 array of shape (num_images, norm_size, norm_size, 3) with
              pixel values scaled to [0, 1].

          Raises:
              KeyError: if a file-name prefix is not a key of ``dicClass``.
              ValueError: if a file cannot be decoded as an image.
          """
          imageList = []
          listImage = os.listdir(datapath)  # every file in the data directory
          for img in listImage:
              labelName = dicClass[img.split('.')[0]]  # name prefix -> class index
              print(labelName)
              labelList.append(labelName)
              dataImgPath = os.path.join(datapath, img)
              print(dataImgPath)
              # np.fromfile + imdecode is presumably used instead of cv2.imread so that
              # non-ASCII paths work on Windows. IMREAD_COLOR always yields 3 channels,
              # so grayscale/RGBA files match the model's (norm_size, norm_size, 3)
              # input instead of producing a ragged array.
              image = cv2.imdecode(np.fromfile(dataImgPath, dtype=np.uint8), cv2.IMREAD_COLOR)
              if image is None:
                  raise ValueError('could not decode image: %s' % dataImgPath)
              image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
              image = img_to_array(image)
              imageList.append(image)
          # float32 is half the size of the float64 produced by an integer array / 255.0.
          imageList = np.array(imageList, dtype="float32") / 255.0
          return imageList
     
    
- 
    
     
    
    
     
      # Read the whole dataset into memory. loadImageData fills labelList as a side
      # effect, so afterwards the list is converted into an array aligned with imageArr.
      print("开始加载数据")
      imageArr = loadImageData()
      labelList = np.asarray(labelList)
      print("加载数据完成")
      print(labelList)
     
    
- 
    
     
    
    
     
      def bn_prelu(x):
          """Batch-normalize ``x`` and pass it through a PReLU activation.

          Args:
              x: Keras tensor produced by the preceding layer.

          Returns:
              The activated Keras tensor.
          """
          normalized = BatchNormalization(epsilon=1e-5)(x)
          return PReLU()(normalized)
     
    
- 
    
     
    
    
      
     
    
- 
    
     
    
    
     
      def build_model(out_dims, input_shape=(norm_size, norm_size, 3), dropout_rate=0.5):
          """Build the functional-API CNN classifier.

          Layout: four stages of two (3x3 Conv2D -> BatchNorm -> PReLU) units with
          32, 64, 128 and 256 filters. The very first convolution uses stride 2,
          stages 1-3 end with 2x2 max pooling, and the last stage feeds global
          average pooling, dropout and a softmax Dense head.

          Args:
              out_dims: number of output classes.
              input_shape: (height, width, channels) of the input images.
              dropout_rate: dropout fraction applied to the pooled features
                  (defaults to 0.5).

          Returns:
              An uncompiled Keras ``Model`` mapping images to class probabilities.

          NOTE(review): ``Input``, ``Activation``, ``PReLU`` and ``Model`` are imported
          from the private ``tensorflow.python.keras`` package while the other layers
          come from ``tensorflow.keras``; on TF releases where those are different
          packages the mix can fail - consider importing everything from
          ``tensorflow.keras``.
          """
          inputs_dim = Input(input_shape)
          x = inputs_dim
          stage_filters = (32, 64, 128, 256)
          for stage, filters in enumerate(stage_filters):
              # Only the very first convolution downsamples (stride 2).
              first_strides = (2, 2) if stage == 0 else (1, 1)
              x = bn_prelu(Conv2D(filters, (3, 3), strides=first_strides, padding='same')(x))
              x = bn_prelu(Conv2D(filters, (3, 3), strides=(1, 1), padding='same')(x))
              # Pool between stages; the last stage goes straight to global pooling.
              if stage < len(stage_filters) - 1:
                  x = MaxPooling2D(pool_size=(2, 2))(x)
          x = GlobalAveragePooling2D()(x)
          dp_1 = Dropout(dropout_rate)(x)
          fc2 = Dense(out_dims)(dp_1)
          # Softmax (not sigmoid): one probability per class, summing to 1, matching
          # the sparse_categorical_crossentropy loss the model is compiled with.
          fc2 = Activation('softmax')(fc2)
          model = Model(inputs=inputs_dim, outputs=fc2)
          return model
     
    
- 
    
     
    
    
     
      model = build_model(labelnum)  # Build the network.
      # `learning_rate` is the supported keyword; `lr` is deprecated and is rejected
      # by newer Keras optimizers.
      optimizer = Adam(learning_rate=INIT_LR)
      # Integer class labels + softmax output -> sparse categorical cross-entropy.
      model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
     
    
- 
    
     
    
    
     
      # Hold out 30% of the samples for validation; the fixed seed makes the split reproducible.
      trainX, valX, trainY, valY = train_test_split(imageArr, labelList, test_size=0.3, random_state=42)
      from tensorflow.keras.preprocessing.image import ImageDataGenerator
      # Geometric augmentation for the training set only.
      # featurewise_center / featurewise_std_normalization are deliberately not used:
      # they only take effect after train_datagen.fit(trainX) (otherwise Keras just
      # warns), and the same statistics would then be needed for validation and
      # inference. Pixels are already scaled to [0, 1] by loadImageData.
      train_datagen = ImageDataGenerator(
          rotation_range=20,
          width_shift_range=0.2,
          height_shift_range=0.2,
          horizontal_flip=True,
      )
      val_datagen = ImageDataGenerator()  # No augmentation for validation data.
      train_generator = train_datagen.flow(trainX, trainY, batch_size=batch_size, shuffle=True)
      # Validation order does not change the metrics, so keep it deterministic.
      val_generator = val_datagen.flow(valX, valY, batch_size=batch_size, shuffle=False)
     
    
- 
    
     
    
    
     
      # Save weights only when validation accuracy reaches a new best.
      checkpointer = ModelCheckpoint(
          filepath='weights_best_simple_model.hdf5',
          monitor='val_accuracy',
          mode='max',
          save_best_only=True,
          verbose=1,
      )
      # Halve the learning rate after 10 epochs without a validation-accuracy
      # improvement, never going below 1e-6.
      reduce = ReduceLROnPlateau(
          monitor='val_accuracy',
          mode='max',
          factor=0.5,
          patience=10,
          min_lr=1e-6,
          verbose=1,
      )
     
    
- 
    
     
    
    
     
      # `fit` accepts the generators directly (`fit_generator` is deprecated in TF 2.x).
      # len(generator) is the batch count rounded up, so the step counts are integers
      # and a final partial batch is still used.
      history = model.fit(train_generator,
                          steps_per_epoch=len(train_generator),
                          validation_data=val_generator,
                          epochs=EPOCHS,
                          validation_steps=len(val_generator),
                          callbacks=[checkpointer, reduce],
                          verbose=1, shuffle=True)
      # Saves the model as of the last epoch; the best epoch by val_accuracy is in
      # weights_best_simple_model.hdf5, written by the checkpoint callback.
      model.save('my_model_.h5')
      print(history)
     
    
- 
    
     
    
    
     
      import os
      import matplotlib.pyplot as plt

      loss_trend_graph_path = r"WW_loss.jpg"
      acc_trend_graph_path = r"WW_acc.jpg"
      # Set to True only if the machine should power off after training (Windows only).
      SHUTDOWN_WHEN_DONE = False


      def _plot_history(metric, title, out_path):
          """Plot training and validation curves of ``metric`` from the global ``history`` and save to ``out_path``."""
          fig = plt.figure()
          plt.plot(history.history[metric])
          plt.plot(history.history["val_" + metric])
          plt.title(title)
          plt.ylabel(metric)
          plt.xlabel("epoch")
          # The second curve comes from the validation split, not a separate test set.
          plt.legend(["train", "validation"], loc="upper left")
          plt.savefig(out_path)
          plt.close(fig)


      print("Now,we start drawing the loss and acc trends graph...")
      _plot_history("accuracy", "Model accuracy", acc_trend_graph_path)
      _plot_history("loss", "Model loss", loss_trend_graph_path)
      print("We are done, everything seems OK...")

      # Powering off is opt-in and only attempted on Windows, where
      # `shutdown -s -t 10` means "shut down in 10 seconds".
      if SHUTDOWN_WHEN_DONE and os.name == "nt":
          os.system("shutdown -s -t 10")
     
    
 测试部分
1、导入依赖
import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array
from  tensorflow.keras.models import load_model
import time
2、设置全局参数
norm_size = 100  # Must match the input size used for training.
imagelist = []
# Class index -> class name, in the same order as dicClass during training.
emotion_labels = {0: 'cat', 1: 'dog'}
3、加载模型
# Load the model saved at the end of training, then start the inference timer.
emotion_classifier = load_model(
    "my_model_.h5"
)
t1 = time.time()
4、处理图片
# np.fromfile + imdecode is presumably used so that non-ASCII paths work on Windows.
# IMREAD_COLOR guarantees the 3-channel input the model was built for.
image = cv2.imdecode(np.fromfile('test/8.jpg', dtype=np.uint8), cv2.IMREAD_COLOR)
if image is None:
    raise ValueError('could not decode image: test/8.jpg')
# Same preprocessing as training: resize, convert to array, scale to [0, 1] as float32.
image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
image = img_to_array(image)
imagelist.append(image)
imageList = np.array(imagelist, dtype="float32") / 255.0
5、预测类别
# predict returns one probability row per image; for the single image here argmax
# over the whole output is its most likely class index (cast to a plain int key).
pre = int(np.argmax(emotion_classifier.predict(imageList)))
emotion = emotion_labels[pre]
t2 = time.time()
print(emotion)
t3 = t2 - t1  # Seconds for preprocessing + prediction (model loading happens before t1).
print(t3)
完整代码
  
   - 
    
     
    
    
     
      import cv2
     
    
- 
    
     
    
    
     
      import numpy as np
     
    
- 
    
     
    
    
     
      from tensorflow.keras.preprocessing.image import img_to_array
     
    
- 
    
     
    
    
     
      from  tensorflow.keras.models import load_model
     
    
- 
    
     
    
    
     
      import time
     
    
- 
    
     
    
    
     
      norm_size = 100  # Must match the input size used for training.
      imagelist = []
      # Class index -> class name, in the same order as dicClass during training.
      emotion_labels = {
          0: 'cat',
          1: 'dog'
      }
      emotion_classifier = load_model("my_model_.h5")
      t1 = time.time()
      # np.fromfile + imdecode is presumably used so that non-ASCII paths work on Windows.
      # IMREAD_COLOR guarantees the 3-channel input the model was built for.
      image = cv2.imdecode(np.fromfile('test/8.jpg', dtype=np.uint8), cv2.IMREAD_COLOR)
      if image is None:
          raise ValueError('could not decode image: test/8.jpg')
      # Same preprocessing as training: resize, convert to array, scale to [0, 1] as float32.
      image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
      image = img_to_array(image)
      imagelist.append(image)
      imageList = np.array(imagelist, dtype="float32") / 255.0
      # For the single image, argmax over the prediction is its most likely class index.
      pre = int(np.argmax(emotion_classifier.predict(imageList)))
      emotion = emotion_labels[pre]
      t2 = time.time()
      print(emotion)
      t3 = t2 - t1  # Seconds for preprocessing + prediction (model loading happens before t1).
      print(t3)
     
    
 文章来源: wanghao.blog.csdn.net,作者:AI浩,版权归原作者所有,如需转载,请联系作者。
原文链接:wanghao.blog.csdn.net/article/details/106166653
- 点赞
- 收藏
- 关注作者
 
             
           
评论(0)