《AI安全之对抗样本入门》—3.5 MXNet
3.5 MXNet
MXNet是亚马逊开发的深度学习库,它拥有类似于Theano和TensorFlow的数据流图,并且可以在常见的硬件平台上运行。MXNet还提供了R、C++、Scala等语言的接口。我们以解决经典的手写数字识别的问题为例,介绍MXNet的基本使用方法,代码路径为:
https://github.com/duoergun0729/adversarial_examples/blob/master/code/2-mxnet.ipynb
1. 加载相关库
加载处理经典的手写数字识别问题相关的Python库:
import mxnet as mx
import logging
2. 加载数据集
MXNet中也针对常见的数据集进行了封装,免去了用户手工下载的过程并简化了预处理的过程:
mnist = mx.test_utils.get_mnist()
batch_size = 128
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'],
batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'],
batch_size)
3. 定义网络结构
使用与Keras几乎完全相同的网络结构,只不过省略了Dropout层。
data = mx.sym.var('data')
data = mx.sym.flatten(data=data)
#全连接
fc1 = mx.sym.FullyConnected(data=data, num_hidden = 512)
act1 = mx.sym.Activation(data=fc1, act_type="relu")
fc2 = mx.sym.FullyConnected(data=act1, num_hidden = 512)
act2 = mx.sym.Activation(data=fc2, act_type="relu")
fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10)
# softmax输出
mlp = mx.sym.SoftmaxOutput(data=fc3, name='softmax')
mlp_model = mx.mod.Module(symbol=mlp, context=ctx)
可视化网络结构,如图3-8所示,值得一提的是MXNet自带的可视化工具非常便于使用。
import matplotlib.pyplot as plt
mx.viz.plot_network(mlp).view()
图3-8 MXNet处理MNIST的网络结构图
4. 定义损失函数和优化器
损失函数使用交叉熵CrossEntropyLoss,优化器使用sgd。
5. 训练与验证
MXNet的训练和验证过程是分开的,训练阶段加载优化器的配置,可以指定每训练100个批次,打印中间结果。
mlp_model.fit(train_iter,
eval_data=val_iter,
optimizer='sgd',
optimizer_params={'learning_rate':0.1},
eval_metric='acc',
batch_end_callback = mx.callback.Speedometer(batch_size, 100),
num_epoch=20)
经过20轮的训练后,在测试集上验证准确度。
test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
acc = mx.metric.Accuracy()
mlp_model.score(test_iter, acc)
print(acc)
最终在测试集上准确度达到了97.62%。
INFO:root:Epoch[19] Train-accuracy=0.996438
INFO:root:Epoch[19] Time cost=4.017 INFO:root:Epoch[19] Validation-accuracy=0.976167
EvalMetric: {'accuracy': 0.9761669303797469}
MXNet保存的模型文件后缀为parms。
mlp_model.save_checkpoint('models/mxnet.parms',20)
- 点赞
- 收藏
- 关注作者
评论(0)