pytorch小难点

举报
yd_234306724 发表于 2020/12/27 23:17:25 2020/12/27
【摘要】 总结pytorch语法 scatter_(input, dim, index, src) >>> x = torch.rand(2, 5) >>> x 0.4319 0.6500 0.4080 0.8760 0.2355 0.2609 0.4711 0.8486 0.8573 0.1029 [torch.FloatTensor of size ...

总结pytorch语法

scatter_(input, dim, index, src)

>>> x = torch.rand(2, 5)
>>> x

 0.4319  0.6500  0.4080  0.8760  0.2355
 0.2609  0.4711  0.8486  0.8573  0.1029
[torch.FloatTensor of size 2x5]

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

LongTensor的shape刚好与x的shape对应,也就是LongTensor每个index指定x中一个数据的填充位置。dim=0,表示按行填充,主要理解按行填充。举例LongTensor中的第0行第2列index=2,表示在第2行(从0开始)进行填充填充,对应到zeros(3, 5)中就是位置(2,2)。所以此处要求zeros(3, 5)的列数要与x列数相同,而LongTensor中的index最大值应与zeros(3, 5)行数相一致,其于地方填0。

>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

 0.4319  0.4711  0.8486  0.8760  0.2355
 0.0000  0.6500  0.0000  0.8573  0.0000
 0.2609  0.0000  0.4080  0.0000  0.1029
[torch.FloatTensor of size 3x5]

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

同上理,可以把1.23看成[[1.23], [1.23]]。此处按列填充,LongTensor中的index=2对应zeros(2, 4)的(0,2)位置,其于地方填0。

>>> z = torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23)
>>> z

 0.0000  0.0000  1.2300  0.0000
 0.0000  0.0000  0.0000  1.2300
[torch.FloatTensor of size 2x4]

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

zip

  # >>>a = [1,2,3] >>> b = [4,5,6] >>> c = [4,5,6,7,8] >>> zipped = zip(a,b) # 打包为元组的列表 [(1, 4), (2, 5), (3, 6)] >>> zip(a,c) # 元素个数与最短的列表一致 [(1, 4), (2, 5), (3, 6)] >>> zip(*zipped) # 与 zip 相反,*zipped 可理解为解压,返回二维矩阵式 [(1, 2, 3), (4, 5, 6)]

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

torch.randn(*sizes, out=None)

返回一个张量,包含了从标准正态分布(均值为0,方差为1,即高斯白噪声)中抽取的一组随机数。张量的形状由参数sizes定义。

torch.randn(2, 3)
0.5419 0.1594 -0.0413
-2.7937 0.9534 0.4561
[torch.FloatTensor of size 2x3]


  
 
  • 1
  • 2
  • 3
  • 4
  • 5

torch.rand(*sizes, out=None)

返回一个张量,包含了从区间[0, 1)的均匀分布中抽取的一组随机数。张量的形状由参数sizes定义

torch.rand(2, 3)
0.0836 0.6151 0.6958
0.6998 0.2560 0.0139
[torch.FloatTensor of size 2x3]


  
 
  • 1
  • 2
  • 3
  • 4
  • 5

register_buffer()

model.state_dict() 的返回对象是一个 OrderDict ,它以键值对的形式包含模型中需要保存下来的参数,例如:

class MyModule(nn.Module): def __init__(self, input_size, output_size): super(MyModule, self).__init__() self.lin = nn.Linear(input_size, output_size) def forward(self, x): return self.lin(x)

module = MyModule(4, 2)
print(module.state_dict())

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

在这里插入图片描述
模型中的参数就是线性层的 weight 和 bias.

Parameter 和 buffer,一种是反向传播需要被optimizer更新的,称之为 parameter,一种是反向传播不需要被optimizer更新,称之为 buffer
第一种参数我们可以通过 model.parameters() 返回;第二种参数我们可以通过 model.buffers() 返回。因为我们的模型保存的是 state_dict 返回的 OrderDict,所以这两种参数不仅要满足是否需要被更新的要求,还需要被保存到OrderDict。

torch.bmm

计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h,w),tensor b的size为(b,w,h),注意两个tensor的维度必须为3.

>>> cc=torch.randn((2,2,5))
>>>print(cc)
tensor([[[ 1.4873, -0.7482, -0.6734, -0.9682,  1.2869], [ 0.0550, -0.4461, -0.1102, -0.0797, -0.8349]], [[-0.6872,  1.1920, -0.9732,  0.4580,  0.7901], [ 0.3035,  0.2022,  0.8815,  0.9982, -1.1892]]])
>>>dd=torch.reshape(cc,(2,5,2))
>>> print(dd)
tensor([[[ 1.4873, -0.7482], [-0.6734, -0.9682], [ 1.2869,  0.0550], [-0.4461, -0.1102], [-0.0797, -0.8349]], [[-0.6872,  1.1920], [-0.9732,  0.4580], [ 0.7901,  0.3035], [ 0.2022,  0.8815], [ 0.9982, -1.1892]]])
>>>e=torch.bmm(cc,dd)
>>> print(e)
tensor([[[ 2.1787, -1.3931], [ 0.3425,  1.0906]], [[-0.5754, -1.1045], [-0.6941,  3.0161]]])
 >>> e.size()
torch.Size([2, 2, 2])

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

masked_fill

mask中数值为1的位置在 tensor a 中相应的位置填充成了-1000000000, mask中数值为0的位置在tensor a 中填充-1000000000



>>> t = torch.randn(3,2)
>>> t
tensor([[-0.9180, -0.4654], [ 0.9866, -1.3063], [ 1.8359,  1.1607]])
>>> m = torch.randint(0,2,(3,2))
>>> m
tensor([[0, 1], [1, 1], [1, 0]])
>>> m == 0
tensor([[ True, False], [False, False], [False,  True]])
>>> t.masked_fill(m == 0, -1e9)
tensor([[-1.0000e+09, -4.6544e-01], [ 9.8660e-01, -1.3063e+00],



  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

文章来源: blog.csdn.net,作者:快了的程序猿小可哥,版权归原作者所有,如需转载,请联系作者。

原文链接:blog.csdn.net/qq_35914625/article/details/111409869

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。