Pytorch 对比expand和repeat函数

举报
风吹稻花香 发表于 2022/04/10 22:34:58 2022/04/10
【摘要】 torch.Tensor有两个实例方法可以用来扩展某维的数据的尺寸,分别是 repeat()和 expand()。 expand和repeat函数是pytorch中常用于进行张量数据复制和维度扩展的函数,但其工作机制差别很大,本文对这两个函数进行对比。 1. expand tensor.expand(*sizes)expand函...

torch.Tensor有两个实例方法可以用来扩展某维的数据的尺寸,分别是 repeat()expand()

expand和repeat函数是pytorch中常用于进行张量数据复制和维度扩展的函数,但其工作机制差别很大,本文对这两个函数进行对比。

1. expand
tensor.expand(*sizes)

expand函数用于将张量中单数维的数据扩展到指定的size。

首先解释下什么叫单数维(singleton dimensions),张量在某个维度上的size为1,则称为单数维。比如zeros(2,3,4)不存在单数维,而zeros(2,1,4)在第二个维度(即维度1)上为单数维。expand函数仅仅能作用于这些单数维的维度上。

参数*sizes用于逐个指定各个维度扩展后的大小(也可以理解为拓展的次数),对于不需要或者无法(即非单数维)进行扩展的维度,对应位置可写上原始维度大小或直接写作-1。

expand函数可能导致原始张量的升维,其作用在张量前面的维度上,因此通过expand函数可将张量数据复制多份(可理解为沿着第一个batch的维度上)。

另一个值得注意的点是:expand函数并不会重新分配内存,返回结果仅仅是原始张量上的一个视图。

下面为几个简单的示例:


  
  1. import torch
  2. a = tensor([1, 0, 2])
  3. b = a.expand(2, -1)   # 第一个维度为升维,第二个维度保持原阳
  4. # b为   tensor([[1, 0, 2],  [1, 0, 2]])
  5. a = torch.tensor([[1], [0], [2]])
  6. b = a.expand(-1, 2)   # 保持第一个维度,第二个维度只有一个元素,可扩展
  7. # b为  tensor([[1, 1],
  8. #              [0, 0],
  9. #              [2, 2]])


2. expand_as
expand_as函数可视为expand的另一种表达,其size通过函数传递的目标张量的size来定义。


  
  1. import torch
  2. a = torch.tensor([1, 0, 2])
  3. b = torch.zeros(2, 3)
  4. c = a.expand_as(b)  # a照着b的维度大小进行拓展
  5. # c为 tensor([[1, 0, 2],
  6. #        [1, 0, 2]])



3. repeat
前文提及expand仅能作用于单数维,那对于非单数维的拓展,那就需要借助于repeat函数了。

tensor.repeat(*sizes)
1
参数*sizes指定了原始张量在各维度上复制的次数。整个原始张量作为一个整体进行复制,这与Numpy中的repeat函数截然不同,而更接近于tile函数的效果。

与expand不同,repeat函数会真正的复制数据并存放于内存中。

下面是一个简单的例子:


  
  1. import torch
  2. x = torch.tensor([[1, 2, 3],[1, 2, 3]])
  3. x = torch.tensor([1, 2, 3])
  4. print(x.size())
  5. print(x)
  6. y = x.repeat(2, 2)
  7. print(y)

结果:

torch.Size([3])
tensor([1, 2, 3])
tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3]])


4. repeat_intertile
Pytorch中,与Numpy的repeat函数相类似的函数为torch.repeat_interleave:

torch.repeat_interleave(input, repeats, dim=None)
1
参数input为原始张量,repeats为指定轴上的复制次数,而dim为复制的操作轴,若取值为None则默认将所有元素进行复制,并会返回一个flatten之后一维张量。

与repeat将整个原始张量作为整体不同,repeat_interleave操作是逐元素的。

下面是一个简单的例子:


  
  1. a = torch.tensor([[1], [0], [2]])
  2. b = torch.repeat_interleave(a, repeats=3)   # 结果flatten
  3. # b为tensor([1, 1, 1, 0, 0, 0, 2, 2, 2])
  4. c = torch.repeat_interleave(a, repeats=3, dim=1)  # 沿着axis=1逐元素复制
  5. # c为tensor([[1, 1, 1],
  6. #        [0, 0, 0],
  7. #        [2, 2, 2]])


原文链接:https://blog.csdn.net/guofei_fly/article/details/104467138

expand函数:

ok的:


  
  1. import torch
  2. x = torch.tensor([[1, 2, 3]])
  3. print(x.size())
  4. y = x.expand(2, 3)
  5. print(y)
  6. x = torch.tensor([[1, 2, 3],[1, 2, 3]])
  7. print(x.size())
  8. y = x.expand(2, 3)
  9. print(y)

异常的:


  
  1. import torch
  2. x = torch.tensor([[1, 2, 3],[1, 2, 3]])
  3. print(x.size())
  4. y = x.expand(4, 3)
  5. print(y)

repeat例子:

repeat的参数必须和data维度相同:


  
  1. import torch
  2. data = torch.tensor([[[1, 2, 3],[4, 5, 6],[7, 8, 9]]])
  3. # x = torch.tensor([1, 2, 3])
  4. print(data.size())
  5. print(data)
  6. y = data.repeat(2,1,1)
  7. print(y.size())
  8. print(y)

文章来源: blog.csdn.net,作者:AI视觉网奇,版权归原作者所有,如需转载,请联系作者。

原文链接:blog.csdn.net/jacke121/article/details/124069005

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。