《AI安全之对抗样本入门》—3.4 PyTorch

举报
华章计算机 发表于 2019/06/17 18:20:10 2019/06/17
【摘要】 本节书摘来自华章计算机《AI安全之对抗样本入门》一书中的第3章,第3.4节,作者是兜哥。

3.4 PyTorch

PyTorch是torch的Python版本,是由Facebook开源的神经网络框架。PyTorch虽然是深度学习框架中的后起之秀,但是发展极其迅猛。PyTorch提供了NumPy风格的Tensor操作,熟悉NumPy操作的用户非常容易上手。我们以解决经典的手写数字识别的问题为例,介绍PyTorch的基本使用方法,代码路径为:

https://github.com/duoergun0729/adversarial_examples/blob/master/code/2-pytorch.ipynb

1. 加载相关库

加载处理经典的手写数字识别问题相关的Python库:

import os

import torch

import torchvision

from torch.autograd import Variable

import torch.utils.data.dataloader as Data

2. 加载数据集

PyTorch中针对常见的数据集进行了封装,免去了用户手工下载的过程并简化了预处理的过程。这里需要特别指出的是,PyTorch中每个Tensor包括输入节点,并且都可以有自己的梯度值,因此训练数据集要设置为train=True,测试数据集要设置为train=False:

train_data = torchvision.datasets.MNIST(

 'dataset/mnist-pytorch', train=True,

transform=torchvision.transforms.ToTensor(), download=True

)

test_data = torchvision.datasets.MNIST(

 'dataset/mnist-pytorch', train=False,

transform=torchvision.transforms.ToTensor()

)

如果需要对数据进行归一化,可以进一步使用transforms.Normalize方法:

transform=transforms.Compose([torchvision.transforms.ToTensor(),

                 torchvision.transforms.Normalize([0.5], [0.5])])

第一次运行该程序时,PyTorch会从互联网直接下载数据集并处理:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz

Processing... Done!

3. 定义网络结构

使用与Keras类似的网络结构,即两层隐藏层结构,不过使用BatchNorm层替换了Dropout层,在抵御过拟合的同时加快了训练的收敛速度。在PyTorch中定义网络结构,通常需要继承torch.nn.Module类,重点是在forward中完成前向传播的定义,在init中完成主要网络层的定义:

class Net(torch.nn.Module):

    def __init__(self):

        super(Net, self).__init__()

        self.dense = torch.nn.Sequential(

            #全连接层

            torch.nn.Linear(784, 512),

            #BatchNorm层

            torch.nn.BatchNorm1d(512),

            torch.nn.ReLU(),

            torch.nn.Linear(512, 10),

            torch.nn.ReLU()

        )    def forward(self, x):

        #把输出转换成大小为784的一维向量

        x = x.view(-1, 784)

        x=self.dense(x)

        return torch.nn.functional.log_softmax(x, dim=1)

最后可视化网络结构,细节如图3-7所示。

 image.png

图3-7 PyTorch处理MNIST的网络结构图

4. 定义损失函数和优化器

损失函数使用交叉熵CrossEntropyLoss,优化器使用Adam,优化的对象是全部网络参数:

optimizer = torch.optim.Adam(model.parameters())

loss_func = torch.nn.CrossEntropyLoss()

5. 训练与验证

PyTorch的训练和验证过程是分开的,在训练阶段需要把训练数据进行前向传播后,使用损失函数计算训练数据的真实标签与预测标签之间损失值,然后显示调用反向传递backward(),使用优化器来调整参数,这一操作需要调用optimizer.step():

for i, data in enumerate(train_loader):

            inputs, labels = data

            inputs, labels = inputs.to(device), labels.to(device)

            # 梯度清零

            optimizer.zero_grad()

            # 前向传播

            outputs = model(inputs)

            loss = loss_func(outputs, labels)

            #反向传递

            loss.backward()

            optimizer.step()

每轮训练需要花费较长的时间,为了让训练过程可视化,可以打印训练的中间结果,比如每100个批次打印下平均损失值:

# 每训练100个批次打印一次平均损失值

sum_loss += loss.item()

if (i+1) % 100 == 0:

   print('epoch=%d, batch=%d loss: %.04f'% (epoch + 1, i+1, sum_loss / 100))

        sum_loss = 0.0

验证阶段要手工关闭反向传递,需要通过torch.no_grad()实现:

# 每跑完一次epoch,测试一下准确率进入测试模式,禁止梯度传递

with torch.no_grad():

     correct = 0

     total = 0

     for data in test_loader:

                images, labels = data

                images, labels = images.to(device), labels.to(device)

                outputs = model(images)

                # 取得分最高的那个类

                _, predicted = torch.max(outputs.data, 1)

                total += labels.size(0)

                correct += (predicted == labels).sum()

                print('epoch=%d accuracy=%.02f%%' % (epoch + 1, (100 * correct /

                total)))

经过20轮训练,在测试集上准确度达到了97.00%:

epoch=20, batch=100 loss: 0.0035

epoch=20, batch=200 loss: 0.0049

epoch=20, batch=300 loss: 0.0040

epoch=20, batch=400 loss: 0.0042

epoch=20 accuracy=97.00%

PyTorch保存的模型文件后缀为pth:

torch.save(model.state_dict(), 'models/pytorch-mnist.pth')


【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

举报
请填写举报理由
0/200