《神经网络与PyTorch实战》——3.3 组织张量的元素

举报
华章计算机 发表于 2019/06/05 20:00:49 2019/06/05
【摘要】 本书摘自《神经网络与PyTorch实战》——书中第3章,第3.3.1节,作者是肖智清。

3.3 组织张量的元素

3.3.1 重排张量元素

  本节介绍如何在不改变张量元素个数和各元素的值的情况下改变张量的大小。本节将涉及以下torch.Tensor类的成员方法。

* reshape()、squeeze() 和unsqueeze():这些成员不会改变元素的实际位置;

* permute()、transpose() 和t():这些成员可能改变元素的实际位置。

  接下来详细看看这些成员。

  首先来看torch.Tensor类的成员方法reshape()。reshape() 方法的参数是多个int类型的值。如果想要把一个张量的大小改成,那就让作为reshape () 方法的个参数。

  我们来看个例子。代码清单3-5首先构造了一个大小为 (12,) 的一维张量tc。接着,张量tc被改变成大小为 (3, 2, 2) 的张量t322。再后来,张量t322被改变成大小为 (4, 3) 的张量t43。在这个例子中,各个张量的元素个数都相同,张量中元素的值都是0~11这12个数。但是,各个张量的维度和大小都不同。

代码清单3-5 使用reshape()?在不改变元素个数和各元素的值的情况下改变张量大小

     tc = torch.arange(12) # 张量大小 (12,)

     print('tc = {}'.format(tc))

     t322 = tc.reshape(3, 2, 2) # 张量大小 (3, 2, 2)

     print('t322 = {}'.format(t322))

     t43 = t322.reshape(4, 3) # 张量大小 (4, 3)

     print('t43 = {}'.format(t43))

  在reshape() 参数里,1个维度的大小用代替。如果某个维度的大小用代替,那么该函数就会根据张量总元素的个数和其他各维度的元素个数自动计算这个用指定的维度的大小。例如,运行下面的代码会得到大小为的张量,其中3就是通过计算得到的:

     torch.arange(24).reshape(2, -1, 4) # 大小 = (2, 3, 4)

  部分reshape() 操作可以使用squeeze() 和unsqueeze() 代替。squeeze() 可以消除张量大小中大小为1的维度。例如:

     t = torch.arange(24).reshape(2, 1, 3, 1, 4) # 大小 = (2, 1, 3, 1, 4)

     t.squeeze() # 大小 = (2, 3, 4)

而unsqueeze() 正好相反,它可以增加一些大小为0的维度。增加的维度的位置由关键字参数dims指定。例如:

     t = torch.arange(24).reshape(2, 3, 4) # 大小 = (2, 3, 4)

     t.unsqueeze(dim=2) # 大小 = (2, 3, 1, 4)

  接下来介绍张量的交换(permute)。张量的交换是将张量的各维度重新排列。假设原来的张量大小为,其按照交换得到的张量的大小为。例如:

     t = torch.arange(24).reshape(1, 2, 3, 4) # 大小 = (1, 2, 3, 4)

     t.permute(dims=[2, 0, 1, 3] # 大小 = (3, 1, 2, 4)

  某些张量的交换还可以用转置(transpose)函数实现。张量的转置是将两个维度的索引互换。例如,如果将大小为的张量的第维和第维互换,可以得到大小为的张量。要对张量进行转置,可以使用成员方法transpose()。成员方法transpose() 的两个参数是要交换的两个维度。对于二维张量,还有一个成员方法t(),相当于transpose(0, 1),即将张量的大小从变为。例如,下面的代码就用了这两种方法对张量t12进行了转置,得到同样的结果:

     t12 = torch.tensor([[5., -9.],])

     t21 = t12.transpose(0, 1)

     print('t21 = {}'.format(t21))

     t21 = t12.t()

     print('t21 = {}'.format(t21))


【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。