《神经网络与PyTorch实战》——3.3.3 张量的扩展和拼接
3.3.3 张量的扩展和拼接
本节介绍如何扩展张量,或将多个张量拼接成1个张量,会涉及torch.Tensor类的成员方法repeat()、函数torch.repeat()?和函数torch.cat()。
torch.Tensor类的成员方法repeat() 可以将张量的内容进行重复,使得张量的大小变大。对于大小为的张量,在使用成员方法repeat() 并给定参数时,可以得到大小为的张量,张量的元素就是将原来张量的元素复制份。
例如,在代码清单3-9中,首先构造了一个大小为 (1, 2) 的张量t12。接着通过t12.repeat() 得到大小为 (3, 4) 的张量t34。
代码清单3-9 使用repeat()?函数扩大张量
t12 = torch.tensor([[5., -9.],])
print('t12 = {}'.format(t12))
t34 = t12.repeat(3, 2)
print('t34 = {}'.format(t34))
torch.cat() 函数可以将多个张量拼接成1个张量。torch.cat() 函数有两个参数。第1个是一个张量的列表或是元组,这个函数就是要将这个列表或是元组里面的所有张量拼接起来。第2个参数指示要将这些张量在哪个维度拼接起来。如果要将这些张量在第维拼接起来,那么被拼接的张量的大小需要满足下列条件:
* 所有张量的维度都完全相同,并且大于。拼接得到的张量的维度也就是其中任意一个张量的维度。
* 所有张量的大小只在第维可能不相同,在其他维都相同。拼接得到的张量在第维以外的维度保持不变,在第维的大小是各张量第维大小之和。
代码清单3-10给出了用torch.cat() 函数拼接多个张量的例子。
代码清单3-10 使用torch.cat() 函数拼接多个张量
tp = torch.arange(12).reshape(3, 4)
tn = -tp
tc0 = torch.cat([tp, tn], 0)
print('tc0 = {}'.format(tc0))
tc1 = torch.cat([tp, tp, tn, tn], 1)
print('tc1 = {}'.format(tc1))
拼接张量的另一种方法是使用torch.stack() 函数。这个函数同样有张量列表(或元组)和维度两个参数。torch.stack() 函数与torch.cat() 函数的不同之处在于,torch.stack() 函数要求输入张量的大小完全相同,得到的张量的维度会比输入的张量的大小多1,并且多出的那个维度就是拼接的维度,那个维度上的大小就是输入张量的个数。代码清单3-11给出了使用torch.stack() 拼接张量的例子。
代码清单3-11 使用torch.stack() 拼接张量
tp = torch.arange(12).reshape(3, 4)
tn = -tp
ts0 = torch.stack([tp, tn], 0)
print('ts0 = {}'.format(ts0))
ts1 = torch.stack([tp, tp, tn, tn], 1)
print('ts1 = {}'.format(ts1))
- 点赞
- 收藏
- 关注作者
评论(0)