《AI安全之对抗样本入门》—3.5 MXNet

举报
华章计算机 发表于 2019/06/17 18:21:53 2019/06/17
【摘要】 本节书摘来自华章计算机《AI安全之对抗样本入门》一书中的第3章,第3.5节,作者是兜哥。

3.5 MXNet

MXNet是亚马逊开发的深度学习库,它拥有类似于Theano和TensorFlow的数据流图,并且可以在常见的硬件平台上运行。MXNet还提供了R、C++、Scala等语言的接口。我们以解决经典的手写数字识别的问题为例,介绍MXNet的基本使用方法,代码路径为:

https://github.com/duoergun0729/adversarial_examples/blob/master/code/2-mxnet.ipynb

1. 加载相关库

加载处理经典的手写数字识别问题相关的Python库:

import mxnet as mx

import logging

2. 加载数据集

MXNet中也针对常见的数据集进行了封装,免去了用户手工下载的过程并简化了预处理的过程:

mnist = mx.test_utils.get_mnist()

batch_size = 128

train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'],

batch_size, shuffle=True)

val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'],

batch_size)

3. 定义网络结构

使用与Keras几乎完全相同的网络结构,只不过省略了Dropout层。

data = mx.sym.var('data')

data = mx.sym.flatten(data=data)

#全连接

fc1  = mx.sym.FullyConnected(data=data, num_hidden = 512)

act1 = mx.sym.Activation(data=fc1, act_type="relu")

fc2  = mx.sym.FullyConnected(data=act1, num_hidden = 512)

act2 = mx.sym.Activation(data=fc2, act_type="relu")

fc3  = mx.sym.FullyConnected(data=act2, num_hidden=10)

# softmax输出

mlp  = mx.sym.SoftmaxOutput(data=fc3, name='softmax')

mlp_model = mx.mod.Module(symbol=mlp, context=ctx)

可视化网络结构,如图3-8所示,值得一提的是MXNet自带的可视化工具非常便于使用。

import matplotlib.pyplot as plt

mx.viz.plot_network(mlp).view()

 

 image.png

图3-8 MXNet处理MNIST的网络结构图

4. 定义损失函数和优化器

损失函数使用交叉熵CrossEntropyLoss,优化器使用sgd。

5. 训练与验证

MXNet的训练和验证过程是分开的,训练阶段加载优化器的配置,可以指定每训练100个批次,打印中间结果。

mlp_model.fit(train_iter,

              eval_data=val_iter, 

              optimizer='sgd', 

              optimizer_params={'learning_rate':0.1},

              eval_metric='acc',

              batch_end_callback = mx.callback.Speedometer(batch_size, 100),

              num_epoch=20)

经过20轮的训练后,在测试集上验证准确度。

test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

acc = mx.metric.Accuracy()

mlp_model.score(test_iter, acc)

print(acc)

最终在测试集上准确度达到了97.62%。

INFO:root:Epoch[19] Train-accuracy=0.996438

INFO:root:Epoch[19] Time cost=4.017 INFO:root:Epoch[19] Validation-accuracy=0.976167

EvalMetric: {'accuracy': 0.9761669303797469}

MXNet保存的模型文件后缀为parms。

mlp_model.save_checkpoint('models/mxnet.parms',20)


【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。