《神经网络与PyTorch实战》——2.2.2 迷你AlphaGo的完整实现

举报
华章计算机 发表于 2019/06/05 19:44:57 2019/06/05
【摘要】 本书摘自《神经网络与PyTorch实战》——书中第2章,第2.2.2节,作者是肖智清。

2.2.2 迷你AlphaGo的完整实现

  在这一节我们来看一个基于PyTorch的迷你AlphaGo的完整实现。完整实现分为4个部分。

* 确定神经网络结构。这个部分确定神经网络中有几个神经元,神经元之间是怎么连接的。

* 构造测试函数。这个部分确定要针对哪个函数来选择人工神经元的权重。显然,函数不同,设置的权重值就不同。

* 设置人工神经元的权重值。根据测试函数,让PyTorch自动设置合适的神经元权重值。完成这步后,神经网络就完全确定了。

* 测试人工神经网络的效果。这步验证之前实现的人工神经网络能够很好地完成“迷你AlphaGo”的功能。

  接下来依次来看这四部分的完整代码。将这四部分完整代码依次拼接在一起,就得到了完整的程序(而不需要再添加其他任何代码)。代码下载地址参见本书前言部分。

注意:如果你觉得这些代码难以理解,请不用担心。这里只是希望你有一个初步的概念。本书的后续章节会逐步介绍如何理解或编写这些代码。等你读完本书,你也能写出这样的代码。

  首先来看第一部分代码。如代码清单2-1所示,这部分代码的主要功能是确定神经网络的结构。具体而言,以“import”和“from”开头的第1行代码告诉计算机我要用PyTorch了,让机器做好准备。接下来用一个序列构造神经网络。在这个序列中,有表示连接的Linear类,也有表示非线性运算的ReLU类。在这个序列中一共有3个Linear类实例,说明这个神经网络有3层。第一个Linear类实例用参数3和8构造,这两个参数说明每个神经元都有3个输入,一共有8个神经元。这个序列中有两个ReLU类实例,也就是说,其中两个层的神经元的非线性函数都是。

  你可能会有疑问:为什么这个神经网络的最后一层没有使用非线性函数?这是因为,我们希望将要制作的“迷你AlphaGo”应用既能输出的结果,也能输出的结果。如果在最后的输出中设置了,那就不可能输出的结果。因此这里最后一层的神经元没有使用非线性函数。

代码清单2-1 神经网络结构搭建代码

     from torch.nn import Linear, ReLU, Sequential

     net = Sequential(

             Linear(3, 8), # 第1层有8个神经元

             ReLU(), # 第1层神经元的非线性函数是max(·,0)

             Linear(8, 8), # 第2层有8个神经元

             ReLU(), # 第2层的神经元的非线性函数是max(·,0)

             Linear(8, 1), # 第3层有1个神经元

             )

  接下来看第二部分代码。在这部分中,将定义函数。当然,这个函数的形式可以是任意的,而且神经网络不需要知道这个函数的具体形式。但是,神经网络权重的确定和函数有关。函数不同,神经网络就需要不同的权重值。因此,给出函数是有必要的。这部分代码如代码清单2-2所示。这里的代码需要用到一些PyTorch的编程知识,我们目前只需要知道这段代码定义了函数

其中,。

代码清单2-2 函数的定义

     def g(x, y):

         x0, x1, x2 = x[:, 0] ** 0, x[:, 1] ** 1, x[:, 2] ** 2

         y0 = y[:, 0]

         return (x0 + x1 + x2) * y0 - y0 * y0 - x0 * x1 * x2

  接下来看第三部分代码。这部分代码见代码清单2-3。在这部分代码中,在PyTorch的帮助下为神经网络中的每个神经元找到合适的权重。在这段代码中,我们没有显式地为每个神经元指定权重值,而是使用了一个优化器。代码第3行构造了优化器optimizer。这个优化器每次可以改良所有权重值,但是这个改良不是一步到位的。我们需要让优化器反复改良很多次,才能让神经网络的所有权重都合适。在这段代码中以“for”开头的语句说明整个过程需要循环很多次(实际上是1000次),而后面缩进的语句都是要循环的内容。在循环的内容中,我们需要告诉优化器每次改良的依据是什么。因此后面缩进的代码先告诉优化器改良的依据,然后通过“optimizer.step()”语句完成权重的改良。完成循环后,神经网络中所有的权重值就都比较合适了。这时候我们就训练好了人工神经网络。

代码清单2-3 使用优化器确定神经元的权重值

     import torch 

     from torch.optim import Adam

     

     optimizer = Adam(net.parameters())

     for step in range(1000):

         optimizer.zero_grad()

         x = torch.randn(1000, 3)

         y = net(x)

         outputs = g(x, y)

         loss = -torch.sum(outputs)

         loss.backward()

         optimizer.step()

         if step % 100 == 0:

             print ('第{}次迭代损失 = {}'.format(step, loss))

  那么,实现好的人工神经网络是不是正确完成了“迷你AlphaGo”的任务目标呢?第四部分的代码就来验证。这部分代码见代码清单2-4。这部分的代码又可以细分为三小部分。

* 生成测试数据。随机生成一些数据作为输入,我们将看到人工神经网络是不是能很好地计算出对应的。生成的输入数据被以“print”开头的语句输出在电脑屏幕上。

* 调用人工神经网络,看看在给定输入下神经网络的输出是什么。神经网络的输出和对应的函数的值被打印了出来。

* 计算理论最优结果作为参考。这部分并没有用到人工神经网络,只是用这个结果和神经网络的输出做比较。如果的具体形式为

,那么这个实际上是一个关于的二次函数,它的开口向下。利用二次函

数的知识,可以知道,这个二次函数在处取到最大值。因此,理论上的最佳输出为。这部分代码将理论上的最佳输出和最佳

输出对应的函数的值打印了出来。

代码清单2-4 验证生成的人工神经网络

     # 生成测试数据

     x_test = torch.randn(2, 3)

     print ('测试输入: {}'.format(x_test))

     

     # 查看神经网络的计算结果

     y_test = net(x_test)

     print ('人工神经网络计算结果: {}'.format(y_test))

     print ('g的值: {}'.format(g(x_test, y_test)))

     

     # 根据理论计算参考答案

     def argmax_g(x):

         x0, x1, x2 = x[:, 0] ** 0, x[:, 1] ** 1, x[:, 2] ** 2

         return 0.5 * (x0 + x1 + x2)[:, None]

     yref_test = argmax_g(x_test)

     print ('理论最优值: {}'.format(yref_test))

     print ('g的值: {}'.format(g(x_test, yref_test)))

  由于第四部分验证代码的输入是随机确定的,所以每次运行的输入和输出都不一样。下面给出某次运行结果:

     测试输入: tensor([[ 0.2487,  0.3399, -0.4967],

             [ 1.0140,  0.1038, -0.1002]])

     人工神经网络计算结果: tensor([[ 0.8289],

             [ 0.4975]])

     g的值: tensor([ 0.5442,  0.3056])

     理论最优值: tensor([[ 0.7933],

             [ 0.5569]])

     g的值: tensor([ 0.5455,  0.3091])

  比较神经网络的输出结果和理论计算结果,可以断定,我们的人工神经网络已经正确地输出了最优结果,实现了“迷你AlphaGo”的功能。


【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。