网络优化方法--Dropout
@toc
1、Dropout介绍
Dropout 也是一种用于抵抗过拟合的技术,它试图改变网络本身来对网络进行优化。我 们先来了解一下它的工作机制,当我们训练一个普通的神经网络时,网络的结构可能如图所示。
Dropout 通常是在神经网络隐藏层的部分使用,使用的时候会临时关闭掉一部分的神经 元,我们可以通过一个参数来控制神经元被关闭的概率,网络结构如图所示。
更详细的流程如下:
- 在模型训练阶段我们可以先给 Dropout 参数设置一个值,例如 0.4。意思是 大约 60%的神经元是工作的,大约 40%神经元是不工作的
- 给需要进行Dropout的神经网络层的每一个神经元生成一个0-1 的随机数(一 般是对隐藏层进行 Dropout)。如果神经元的随机数小于 0.6,那么该神经元就设置为 工作状态的;如果神经元的随机数大于等于 0.6,那么该神经元就设置为不工作的,不工作状态的意思就是不参与计算和训练,可以当这个神经元不存在。
- 设置好一部分神经元工作一部分神经元不工作之后,我们会发现神经网络的输 出值会发现变化,如上图,如果隐藏层有一半不工作,那么网络输出值就会比原来的值要小,因为计算 WX+b 时,如果 W 矩阵中,有一部分的值变成 0,那么最后 的计算结果肯定会变小。所以为了使用 Dropout 的网络层神经元信号的总和不会发生 太大的变化,对于工作的神经元的输出信号还需要除以 0.4。
- 训练阶段重复 1-3 步骤,每一次都随机选择部分的神经元参与训练。
- 在测试阶段所有的神经元都参与计算。
Dropout 为什么会起作用呢?这个问题很难通过数学推导来证明。我们在介绍 ReLU 激 活函数的时候有提到过神经网络的信号是冗余的,神经网络在做预测时并不需要隐藏层所有神 经元都工作,只需要一部分隐藏层神经元工作即可。我们可以抽象地来理解 Dropout,当我们 使用 Dropout 的时候,就有点像我们在训练很多不同的结构更简单的神经网络,最后测试阶 段再综合所有的网络结构得到结果。或者另外一种理解方式是我们使用 Dropout 的时候减少 了神经元之间的相互关联,同时强制网络使用更少的特征来做预测,可以增加模型的健壮性。
除了这两种理解方式之外还可以有其他的很多理解方式,深度学习中很多技巧都是不能用 数学推导得到同时又比较难理解的。但重要的是这些技巧在实际应用中可以帮助我们得到更好 的结果。
==Dropout 比较适合应用于只有少量数据但是需要训练复杂模型的场景,这类场景在图像 领域比较常见,所以 Dropout 经常用于图像领域。==
2、Dropout程序
这里我们而将看到一个Dropout在MNIST数据集识别中的应用,我们建立两个模型,一个使用Dropout,另一个不使用Dropout,对比两个模型的收敛速度。
代码在Jupyter Notebook中调试。
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Dropout,Flatten
from tensorflow.keras.optimizers import SGD
import matplotlib.pyplot as plt
import numpy as np
# 载入数据集
mnist = tf.keras.datasets.mnist
# 载入训练集和测试集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 对训练集和测试集的数据进行归一化处理,有助于提升模型训练速度
x_train, x_test = x_train / 255.0, x_test / 255.0
# 把训练集和测试集的标签转为独热编码
y_train = tf.keras.utils.to_categorical(y_train,num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test,num_classes=10)
# 模型定义,model1使用Dropout
# Dropout(0.4)表示隐藏层40%神经元不工作
model1 = Sequential([
Flatten(input_shape=(28, 28)),
Dense(units=200,activation='tanh'),
Dropout(0.4),
Dense(units=100,activation='tanh'),
Dropout(0.4),
Dense(units=10,activation='softmax')
])
# 在定义一个一模一样的模型用于对比测试,model2不使用Dropout
# Dropout(0)表示隐藏层所有神经元都工作,相当于没有Dropout
model2 = Sequential([
Flatten(input_shape=(28, 28)),
Dense(units=200,activation='tanh'),
Dropout(0),
Dense(units=100,activation='tanh'),
Dropout(0),
Dense(units=10,activation='softmax')
])
# sgd定义随机梯度下降法优化器
# loss='categorical_crossentropy'定义交叉熵代价函数
# metrics=['accuracy']模型在训练的过程中同时计算准确率
sgd = SGD(0.2)
model1.compile(optimizer=sgd,
loss='categorical_crossentropy',
metrics=['accuracy'])
model2.compile(optimizer=sgd,
loss='categorical_crossentropy',
metrics=['accuracy'])
# 传入训练集数据和标签训练模型
# 周期大小为30(把所有训练集数据训练一次称为训练一个周期)
epochs = 30
# 批次大小为32(每次训练模型传入32个数据进行训练)
batch_size=32
# validation_data设置验证集数据
# 先训练model1
history1 = model1.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test,y_test))
# 再训练model2
history2 = model2.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test,y_test))
训练过程:
Train on 60000 samples, validate on 10000 samples Epoch 1/30 60000/60000 [==============================] - 6s 95us/sample - loss: 0.4173 - accuracy: 0.8728 - val_loss: 0.2200 - val_accuracy: 0.9337 Epoch 2/30 60000/60000 [==============================] - 5s 78us/sample - loss: 0.2786 - accuracy: 0.9171 - val_loss: 0.1616 - val_accuracy: 0.9516 Epoch 3/30 60000/60000 [==============================] - 4s 73us/sample - loss: 0.2384 - accuracy: 0.9293 - val_loss: 0.1603 - val_accuracy: 0.9519 Epoch 4/30 60000/60000 [==============================] - 4s 74us/sample - loss: 0.2182 - accuracy: 0.9347 - val_loss: 0.1393 - val_accuracy: 0.9577 Epoch 5/30 60000/60000 [==============================] - 4s 74us/sample - loss: 0.2014 - accuracy: 0.9400 - val_loss: 0.1257 - val_accuracy: 0.9626 Epoch 6/30 60000/60000 [==============================] - 5s 75us/sample - loss: 0.1881 - accuracy: 0.9453 - val_loss: 0.1236 - val_accuracy: 0.9651 Epoch 7/30 60000/60000 [==============================] - 5s 83us/sample - loss: 0.1748 - accuracy: 0.9483 - val_loss: 0.1107 - val_accuracy: 0.9670 Epoch 8/30 60000/60000 [==============================] - 6s 104us/sample - loss: 0.1683 - accuracy: 0.9494 - val_loss: 0.1131 - val_accuracy: 0.9662 Epoch 9/30 60000/60000 [==============================] - 6s 95us/sample - loss: 0.1597 - accuracy: 0.9517 - val_loss: 0.1066 - val_accuracy: 0.9677 Epoch 10/30 60000/60000 [==============================] - 6s 95us/sample - loss: 0.1534 - accuracy: 0.9541 - val_loss: 0.0945 - val_accuracy: 0.9709 Epoch 11/30 60000/60000 [==============================] - 6s 95us/sample - loss: 0.1511 - accuracy: 0.9547 - val_loss: 0.1054 - val_accuracy: 0.9674 Epoch 12/30 60000/60000 [==============================] - 6s 97us/sample - loss: 0.1481 - accuracy: 0.9548 - val_loss: 0.0930 - val_accuracy: 0.9730 Epoch 13/30 60000/60000 [==============================] - 6s 95us/sample - loss: 0.1406 - accuracy: 0.9586 - val_loss: 0.0937 - val_accuracy: 0.9707 Epoch 14/30 60000/60000 [==============================] - 6s 95us/sample - loss: 0.1381 - accuracy: 0.9588 - val_loss: 0.0904 - val_accuracy: 0.9735 Epoch 15/30 60000/60000 [==============================] - 6s 95us/sample - loss: 0.1348 - accuracy: 0.9597 - val_loss: 0.0934 - val_accuracy: 0.9724 Epoch 16/30 60000/60000 [==============================] - 6s 95us/sample - loss: 0.1304 - accuracy: 0.9614 - val_loss: 0.0865 - val_accuracy: 0.9747 Epoch 17/30 60000/60000 [==============================] - 6s 95us/sample - loss: 0.1262 - accuracy: 0.9628 - val_loss: 0.0871 - val_accuracy: 0.9745 Epoch 18/30 60000/60000 [==============================] - 6s 96us/sample - loss: 0.1255 - accuracy: 0.9628 - val_loss: 0.0856 - val_accuracy: 0.9735 Epoch 19/30 60000/60000 [==============================] - 6s 100us/sample - loss: 0.1248 - accuracy: 0.9616 - val_loss: 0.0826 - val_accuracy: 0.9747 Epoch 20/30 60000/60000 [==============================] - 6s 94us/sample - loss: 0.1180 - accuracy: 0.9651 - val_loss: 0.0847 - val_accuracy: 0.9752 Epoch 21/30 60000/60000 [==============================] - 6s 94us/sample - loss: 0.1163 - accuracy: 0.9648 - val_loss: 0.0869 - val_accuracy: 0.9747 Epoch 22/30 60000/60000 [==============================] - 6s 94us/sample - loss: 0.1171 - accuracy: 0.9650 - val_loss: 0.0813 - val_accuracy: 0.9764 Epoch 23/30 60000/60000 [==============================] - 6s 94us/sample - loss: 0.1160 - accuracy: 0.9647 - val_loss: 0.0872 - val_accuracy: 0.9746 Epoch 24/30 60000/60000 [==============================] - 6s 95us/sample - loss: 0.1100 - accuracy: 0.9664 - val_loss: 0.0850 - val_accuracy: 0.9759 Epoch 25/30 60000/60000 [==============================] - 6s 95us/sample - loss: 0.1095 - accuracy: 0.9671 - val_loss: 0.0815 - val_accuracy: 0.9769 Epoch 26/30 60000/60000 [==============================] - 6s 96us/sample - loss: 0.1087 - accuracy: 0.9668 - val_loss: 0.0799 - val_accuracy: 0.9774 Epoch 27/30 60000/60000 [==============================] - 6s 96us/sample - loss: 0.1084 - accuracy: 0.9674 - val_loss: 0.0811 - val_accuracy: 0.9779 Epoch 28/30 60000/60000 [==============================] - 6s 95us/sample - loss: 0.1055 - accuracy: 0.9683 - val_loss: 0.0794 - val_accuracy: 0.9761 Epoch 29/30 60000/60000 [==============================] - 6s 98us/sample - loss: 0.1030 - accuracy: 0.9689 - val_loss: 0.0803 - val_accuracy: 0.9767 Epoch 30/30 60000/60000 [==============================] - 6s 95us/sample - loss: 0.1036 - accuracy: 0.9682 - val_loss: 0.0770 - val_accuracy: 0.9777 Train on 60000 samples, validate on 10000 samples Epoch 1/30 60000/60000 [==============================] - 6s 99us/sample - loss: 0.2536 - accuracy: 0.9230 - val_loss: 0.1502 - val_accuracy: 0.9537 Epoch 2/30 60000/60000 [==============================] - 6s 94us/sample - loss: 0.1172 - accuracy: 0.9641 - val_loss: 0.1013 - val_accuracy: 0.9688 Epoch 3/30 60000/60000 [==============================] - 6s 94us/sample - loss: 0.0809 - accuracy: 0.9757 - val_loss: 0.1021 - val_accuracy: 0.9659 Epoch 4/30 60000/60000 [==============================] - 6s 94us/sample - loss: 0.0598 - accuracy: 0.9816 - val_loss: 0.0958 - val_accuracy: 0.9699 Epoch 5/30 60000/60000 [==============================] - 6s 93us/sample - loss: 0.0457 - accuracy: 0.9857 - val_loss: 0.0867 - val_accuracy: 0.9749 Epoch 6/30 60000/60000 [==============================] - 6s 93us/sample - loss: 0.0353 - accuracy: 0.9892 - val_loss: 0.0729 - val_accuracy: 0.9770 Epoch 7/30 60000/60000 [==============================] - 6s 98us/sample - loss: 0.0244 - accuracy: 0.9932 - val_loss: 0.0774 - val_accuracy: 0.9762 Epoch 8/30 60000/60000 [==============================] - 6s 96us/sample - loss: 0.0191 - accuracy: 0.9947 - val_loss: 0.0688 - val_accuracy: 0.9782 Epoch 9/30 60000/60000 [==============================] - 6s 96us/sample - loss: 0.0141 - accuracy: 0.9966 - val_loss: 0.0946 - val_accuracy: 0.9702 Epoch 10/30 60000/60000 [==============================] - 7s 111us/sample - loss: 0.0097 - accuracy: 0.9978 - val_loss: 0.0704 - val_accuracy: 0.9785 Epoch 11/30 60000/60000 [==============================] - 6s 107us/sample - loss: 0.0058 - accuracy: 0.9991 - val_loss: 0.0629 - val_accuracy: 0.9813 Epoch 12/30 60000/60000 [==============================] - 6s 99us/sample - loss: 0.0043 - accuracy: 0.9995 - val_loss: 0.0684 - val_accuracy: 0.9800 Epoch 13/30 60000/60000 [==============================] - 6s 98us/sample - loss: 0.0030 - accuracy: 0.9998 - val_loss: 0.0646 - val_accuracy: 0.9808 Epoch 14/30 60000/60000 [==============================] - 6s 98us/sample - loss: 0.0022 - accuracy: 0.9999 - val_loss: 0.0643 - val_accuracy: 0.9815 Epoch 15/30 60000/60000 [==============================] - 6s 106us/sample - loss: 0.0017 - accuracy: 1.0000 - val_loss: 0.0678 - val_accuracy: 0.9804 Epoch 16/30 60000/60000 [==============================] - 6s 95us/sample - loss: 0.0015 - accuracy: 1.0000 - val_loss: 0.0660 - val_accuracy: 0.9811 Epoch 17/30 60000/60000 [==============================] - 6s 95us/sample - loss: 0.0013 - accuracy: 1.0000 - val_loss: 0.0667 - val_accuracy: 0.9812 Epoch 18/30 60000/60000 [==============================] - 6s 95us/sample - loss: 0.0011 - accuracy: 1.0000 - val_loss: 0.0670 - val_accuracy: 0.9814 Epoch 19/30 60000/60000 [==============================] - 6s 96us/sample - loss: 0.0010 - accuracy: 1.0000 - val_loss: 0.0668 - val_accuracy: 0.9814 Epoch 20/30 60000/60000 [==============================] - 6s 95us/sample - loss: 9.3235e-04 - accuracy: 1.0000 - val_loss: 0.0676 - val_accuracy: 0.9817 Epoch 21/30 60000/60000 [==============================] - 6s 95us/sample - loss: 8.5067e-04 - accuracy: 1.0000 - val_loss: 0.0673 - val_accuracy: 0.9815 Epoch 22/30 60000/60000 [==============================] - 6s 95us/sample - loss: 7.8290e-04 - accuracy: 1.0000 - val_loss: 0.0688 - val_accuracy: 0.9813 Epoch 23/30 60000/60000 [==============================] - 6s 95us/sample - loss: 7.2826e-04 - accuracy: 1.0000 - val_loss: 0.0682 - val_accuracy: 0.9814 Epoch 24/30 60000/60000 [==============================] - 6s 97us/sample - loss: 6.8046e-04 - accuracy: 1.0000 - val_loss: 0.0691 - val_accuracy: 0.9811 Epoch 25/30 60000/60000 [==============================] - 5s 91us/sample - loss: 6.3994e-04 - accuracy: 1.0000 - val_loss: 0.0696 - val_accuracy: 0.9812 Epoch 26/30 60000/60000 [==============================] - 5s 91us/sample - loss: 5.9906e-04 - accuracy: 1.0000 - val_loss: 0.0699 - val_accuracy: 0.9812 Epoch 27/30 60000/60000 [==============================] - 6s 92us/sample - loss: 5.6810e-04 - accuracy: 1.0000 - val_loss: 0.0696 - val_accuracy: 0.9815 Epoch 28/30 60000/60000 [==============================] - 6s 98us/sample - loss: 5.3810e-04 - accuracy: 1.0000 - val_loss: 0.0707 - val_accuracy: 0.9812 Epoch 29/30 60000/60000 [==============================] - 6s 96us/sample - loss: 5.1041e-04 - accuracy: 1.0000 - val_loss: 0.0707 - val_accuracy: 0.9811 Epoch 30/30 60000/60000 [==============================] - 6s 96us/sample - loss: 4.8516e-04 - accuracy: 1.0000 - val_loss: 0.0712 - val_accuracy: 0.9819
这里是用两个模型对比的,所以训练过程包含了两个模型的结果。
# 画出model1验证集准确率曲线图
plt.plot(np.arange(epochs),history1.history['val_accuracy'],c='b',label='Dropout')
# 画出model2验证集准确率曲线图
plt.plot(np.arange(epochs),history2.history['val_accuracy'],c='y',label='FC')
# 图例
plt.legend()
# x坐标描述
plt.xlabel('epochs')
# y坐标描述
plt.ylabel('accuracy')
# 显示图像
plt.show()
模型训练结果前 1-30 周期是使用了 Dropout 的结果,后面的 1-30 周期是没有使用 Dropout 的结果。观察结果我们发现使用了 Dropout 之后训练集准确率和验证集的准确率相差并不是很大,所以能看出 Dropout 确实是可以起到抵抗过拟合的作用。我们还可以发现一个有趣的现象就是前 1-30 周期 model1 的验证集准确率还高于训练集的准确率,这是因为模 型在计算训练集准确率的时候模型还在使用 Dropout,在计算验证集准确率的时候已经不使 用 Dropout 了。使用 Dropout 的时候模型的准确率会稍微降低一些。同时我们也可以发现, 不用 Dropout 的 model2 中测试集的准确率看起来比使用 Dropout 的 model1 要更高。
事实上使用 Dropout 之后模型的收敛速度会变慢一些,所以需要更多的训练次数才能得到最好的结果。
这里不用 Dropout 的 model2 验证集训练 30 个周期最高准确率大概 是 98.2%左右;使用 Dropout 的 model1 如果训练足够多的周期,验证集最高准确率可以达 到 98.8%左右。
- 点赞
- 收藏
- 关注作者
评论(0)