PyTorch: 计算图与动态图机制

举报
timerring 发表于 2022/10/31 09:34:19 2022/10/31
【摘要】 本文已收录于Pytorch系列专栏: Pytorch入门与实践 专栏旨在详解Pytorch,精炼地总结重点,面向入门学习者,掌握Pytorch框架,为数据分析,机器学习及深度学习的代码能力打下坚实的基础。免费订阅,持续更新。 计算图计算图是用来描述运算的有向无环图计算图有两个主要元素:结点 Node边 Edge结点表示数据:如向量,矩阵,张量边表示运算:如加减乘除卷积等用计算图表示:y ...

本文已收录于Pytorch系列专栏: Pytorch入门与实践 专栏旨在详解Pytorch,精炼地总结重点,面向入门学习者,掌握Pytorch框架,为数据分析,机器学习及深度学习的代码能力打下坚实的基础。免费订阅,持续更新。

计算图

计算图是用来描述运算的有向无环图

计算图有两个主要元素:

  • 结点 Node

  • 边 Edge

结点表示数据:如向量,矩阵,张量

边表示运算:如加减乘除卷积等

用计算图表示:y = (x+ w) * (w+1)
a = x + w
b = w + 1
y = a * b

image-20221007140501247

计算图与梯度求导

y = (x+ w) * (w+1)
a = x + w
b = w + 1
y = a * b

image-20221007142316040

y w = y a a w + y b b w = b 1 + a 1 = b + a = ( w + 1 ) + ( x + w ) = 2 w + x + 1 = 2 1 + 2 + 1 = 5 \begin{aligned} \frac{\partial y}{\partial w} &=\frac{\partial y}{\partial a} \frac{\partial a}{\partial w}+\frac{\partial y}{\partial b} \frac{\partial b}{\partial w} \\ &=b * 1+a * 1 \\ &=b+a \\ &=(w+1)+(x+w) \\ &=2 * w+x+1 \\ &=2 * 1+2+1=5 \end{aligned}

可见,对于变量w的求导过程就是寻找它在计算图中的所有路径的求导之和。

code:

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)     # retain_grad()
b = torch.add(w, 1)
y = torch.mul(a, b)

y.backward()
print(w.grad)
tensor([5.])

计算图与梯度求导
y = (x+ w) * (w+1)

叶子结点 :用户创建的结点称为叶子结点,如 X 与 W

is_leaf: 指示张量是否为叶子结点

叶子节点的作用是标志存储叶子节点的梯度,而清除在反向传播过程中的变量的梯度,以达到节省内存的目的。

当然,如果想要保存过程中变量的梯度值,可以采用retain_grad()

grad_fn: 记录创建该张量时所用的方法(函数)

  • y.grad_fn= <MulBackward0>
  • a.grad_fn= <AddBackward0>
  • b.grad_fn= <AddBackward0>

image-20221007142938198

PyTorch的动态图机制

根据计算图搭建方式,可将计算图分为动态图静态图

  • 动态图

    运算与搭建同时进行

    灵活 易调节

    例如动态图 PyTorch:

    image-20221007144304367

  • 静态

    先搭建图, 后运算

    高效 不灵活。

    静态图 TensorFlow

    image-20221007144319338

【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。