[Keras Tutorial for Beginners] 9. Using the GPU and Saving Models with Callbacks


@Author:Runsen

GPU

  • Training a neural network on a GPU is much faster than running it on a CPU

  • Keras supports training on GPUs with the TensorFlow and Theano backends

Documentation: https://keras.io/getting-started/faq/#how-can-i-run-keras-on-gpu

Installing GPU support

  • First, download and install CUDA & cuDNN (assuming you are using an NVIDIA GPU)

  • Installation URL: https://developer.nvidia.com/cudnn

  • Then install tensorflow-gpu (the GPU-enabled build of TensorFlow) by typing the following command in cmd or a terminal

  • pip install tensorflow-gpu

  • Then check that the machine is actually using a GPU device

  • In the example below, there is one GPU device (named "/device:GPU:0")

  • If you are using Google Colab, simply change the runtime type to "GPU"

import tensorflow as tf
from tensorflow.python.client import device_lib

# List every device TensorFlow can see; a "/device:GPU:0" entry means the GPU is available
device_lib.list_local_devices()
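
On TensorFlow 2.x, tf.config offers a shorter way to confirm that a GPU is visible; a minimal sketch (assuming TensorFlow 2.1 or later, where tf.config.list_physical_devices is available):

import tensorflow as tf

# List the GPUs visible to TensorFlow; an empty list means training will run on the CPU
gpus = tf.config.list_physical_devices('GPU')
print('Num GPUs available:', len(gpus))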

Keras Callbacks

Keras callbacks provide useful utilities for the model training process

  • ModelCheckpoint
  • EarlyStopping
  • ReduceLROnPlateau

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.layers import Dense, Activation
data = load_digits()
X_data = data.images
y_data = data.target
X_train, X_test, y_train, y_test = train_test_split(X_data, y_data, test_size = 0.3, random_state = 777)
# reshaping X data => flatten into 1-dimensional
X_train = X_train.reshape((X_train.shape[0], -1))
X_test = X_test.reshape((X_test.shape[0], -1))
# converting y data into categorical (one-hot encoding)
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
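
A quick shape check confirms the preprocessing: each 8x8 digits image is flattened into 64 features and the labels become 10-dimensional one-hot vectors (the exact split sizes depend on the 70/30 train_test_split):

# e.g. (1257, 64) (540, 64) (1257, 10) (540, 10) for the 1797-sample digits set
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)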

1. ModelCheckpoint

  • ModelCheckpoint is used to "checkpoint" the model as training progresses
  • Typically, it is used to save only the best model

def create_model():
    model = Sequential()
    model.add(Dense(100, input_shape = (X_train.shape[1],)))
    model.add(Activation('relu'))
    model.add(Dense(100))
    model.add(Activation('relu'))
    model.add(Dense(y_train.shape[1]))
    model.add(Activation('sigmoid'))
    
    model.compile(optimizer = 'Adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])
    return model
    
model = create_model()
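
Calling model.summary() is a convenient way to confirm the layer stack and parameter counts before training starts:

# Print the layer stack and the number of trainable parameters
model.summary()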

Creating callbacks list

  • The ModelCheckpoint instance is stored in a list and passed to fit() when training

callbacks = [ModelCheckpoint(filepath = 'saved_model.hdf5', monitor='val_accuracy', verbose=1, mode='max')]
model.fit(X_train, y_train, epochs = 10, batch_size = 500, callbacks = callbacks, validation_data = (X_test, y_test))

Loading saved weights

  • The saved weights can be loaded and reused without further training
  • This is especially useful when training takes a long time and the model has to be reused many times
another_model = create_model()
another_model.load_weights('saved_model.hdf5')
another_model.compile(optimizer = 'Adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])
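
The reloaded model can then be evaluated or used for prediction immediately, with no additional training; for example:

# Evaluate the restored weights on the held-out test set
results = another_model.evaluate(X_test, y_test)
print('Accuracy: ', results[1])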

Selecting best model

  • The best model across all epochs can be selected with ModelCheckpoint
  • Set the 'save_best_only' parameter to True
  • Typically, validation accuracy (val_accuracy) is monitored and used as the criterion for selecting the best model

callbacks = [ModelCheckpoint(filepath = 'best_model.hdf5', monitor='val_accuracy', verbose=1, save_best_only = True, mode='max')]
model = create_model()
model.fit(X_train, y_train, epochs = 10, batch_size = 500, callbacks = callbacks, validation_data = (X_test, y_test))
best_model = create_model()
best_model.load_weights('best_model.hdf5')
best_model.compile(optimizer = 'Adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])
results = best_model.evaluate(X_test, y_test)
print('Accuracy: ', results[1])

Accuracy: 0.9740740656852722
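
Because ModelCheckpoint saves the full model by default (unless save_weights_only=True is passed), the checkpoint can also be restored without rebuilding the architecture by using load_model; a minimal sketch:

from tensorflow.keras.models import load_model

# Restore architecture, weights and compile state in one call
restored_model = load_model('best_model.hdf5')
print('Accuracy: ', restored_model.evaluate(X_test, y_test)[1])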

2. EarlyStopping

  • We can set the patience parameter, which specifies the number of epochs with no improvement after which training is stopped

callbacks = [EarlyStopping(monitor = 'accuracy', patience = 1)]
model = create_model()
model.fit(X_train, y_train, epochs = 20, batch_size = 500, callbacks = callbacks, validation_data = (X_test, y_test))
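
In recent tf.keras versions, EarlyStopping also accepts a restore_best_weights argument, so when training stops the model is rolled back to the weights of its best epoch instead of keeping the last (possibly worse) weights; a sketch monitoring validation loss:

# Stop after 3 epochs without val_loss improvement and keep the best weights seen
callbacks = [EarlyStopping(monitor = 'val_loss', patience = 3, restore_best_weights = True)]
model = create_model()
model.fit(X_train, y_train, epochs = 50, batch_size = 500, callbacks = callbacks, validation_data = (X_test, y_test))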

3. Reduce learning rate (ReduceLROnPlateau)

  • In general, it is desirable to reduce the learning rate (learning rate decay) as training progresses

callbacks = [ReduceLROnPlateau(monitor = 'val_loss', factor = 0.5, patience = 5)]
model = create_model()
model.fit(X_train, y_train, epochs = 20, batch_size = 500, callbacks = callbacks, validation_data = (X_test, y_test))
results = model.evaluate(X_test, y_test)
print('Accuracy: ', results[1])

Accuracy: 0.949999988079071
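
ReduceLROnPlateau only lowers the learning rate when the monitored metric stops improving; if a fixed decay schedule is preferred instead, LearningRateScheduler can apply one explicitly. A minimal sketch (the halve-every-10-epochs schedule is purely illustrative; TF 2.x allows the schedule function to take both the epoch index and the current learning rate):

from tensorflow.keras.callbacks import LearningRateScheduler

def step_decay(epoch, lr):
    # Halve the learning rate every 10 epochs (illustrative schedule)
    if epoch > 0 and epoch % 10 == 0:
        return lr * 0.5
    return lr

callbacks = [LearningRateScheduler(step_decay, verbose = 1)]
model = create_model()
model.fit(X_train, y_train, epochs = 20, batch_size = 500, callbacks = callbacks, validation_data = (X_test, y_test))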
