《神经网络与PyTorch实战》——3.3.3 张量的扩展和拼接

举报
华章计算机 发表于 2019/06/05 20:03:10 2019/06/05
【摘要】 本书摘自《神经网络与PyTorch实战》——书中第3章,第3.3.3节,作者是肖智清。

3.3.3 张量的扩展和拼接

  本节介绍如何扩展张量,或将多个张量拼接成1个张量,会涉及torch.Tensor类的成员方法repeat()、函数torch.repeat()?和函数torch.cat()。

  torch.Tensor类的成员方法repeat() 可以将张量的内容进行重复,使得张量的大小变大。对于大小为的张量,在使用成员方法repeat() 并给定参数时,可以得到大小为的张量,张量的元素就是将原来张量的元素复制份。

  例如,在代码清单3-9中,首先构造了一个大小为 (1, 2) 的张量t12。接着通过t12.repeat() 得到大小为 (3, 4) 的张量t34。

代码清单3-9 使用repeat()?函数扩大张量

     t12 = torch.tensor([[5., -9.],])

     print('t12 = {}'.format(t12))

     t34 = t12.repeat(3, 2)

     print('t34 = {}'.format(t34))

  torch.cat() 函数可以将多个张量拼接成1个张量。torch.cat() 函数有两个参数。第1个是一个张量的列表或是元组,这个函数就是要将这个列表或是元组里面的所有张量拼接起来。第2个参数指示要将这些张量在哪个维度拼接起来。如果要将这些张量在第维拼接起来,那么被拼接的张量的大小需要满足下列条件:

* 所有张量的维度都完全相同,并且大于。拼接得到的张量的维度也就是其中任意一个张量的维度。

* 所有张量的大小只在第维可能不相同,在其他维都相同。拼接得到的张量在第维以外的维度保持不变,在第维的大小是各张量第维大小之和。

  代码清单3-10给出了用torch.cat() 函数拼接多个张量的例子。

代码清单3-10 使用torch.cat() 函数拼接多个张量

     tp = torch.arange(12).reshape(3, 4)

     tn = -tp

     tc0 = torch.cat([tp, tn], 0)

     print('tc0 = {}'.format(tc0))

     tc1 = torch.cat([tp, tp, tn, tn], 1)

     print('tc1 = {}'.format(tc1))

  拼接张量的另一种方法是使用torch.stack() 函数。这个函数同样有张量列表(或元组)和维度两个参数。torch.stack() 函数与torch.cat() 函数的不同之处在于,torch.stack() 函数要求输入张量的大小完全相同,得到的张量的维度会比输入的张量的大小多1,并且多出的那个维度就是拼接的维度,那个维度上的大小就是输入张量的个数。代码清单3-11给出了使用torch.stack() 拼接张量的例子。

代码清单3-11 使用torch.stack() 拼接张量

     tp = torch.arange(12).reshape(3, 4)

     tn = -tp

     ts0 = torch.stack([tp, tn], 0)

     print('ts0 = {}'.format(ts0))

     ts1 = torch.stack([tp, tp, tn, tn], 1)

     print('ts1 = {}'.format(ts1))


【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。