《神经网络与PyTorch实战》——3.3.2 选取部分张量元素

举报
华章计算机 发表于 2019/06/05 20:02:07 2019/06/05
【摘要】 本书摘自《神经网络与PyTorch实战》——书中第3章,第3.3.2节,作者是肖智清。

3.3.2 选取部分张量元素

  本节介绍如何选取一个张量中的部分元素,会涉及torch.Tensor类的成员方法index_select()、masked_select() 和take(),并介绍如何利用英文方括号“[]”取元素。

  首先来看index_select()。这个成员方法有两个参数:参数dim是一个int值,它表示对哪个维度进行处理;参数index是一个torch.Tensor类实例,表示选取那个维度的哪些指标。如果调用index_select() 的torch.Tensor类实例的大小为,参数dim的值为,参数index的大小为,那么选取得到的张量大小为

。代码清单3-6给出了一个使用index_select() 选择元素的例子。

代码清单3-6 index_select() 用法举例

     t = torch.arange(24).reshape(2, 3, 4) # 大小 = (2, 3, 4)

     index = torch.tensor([1, 2]) # 大小 (2,)

     t.index_select(1, index) # 大小 (2, 2, 4)

  要对张量的某个维度进行选取,除了使用成员函数index_select() 外,还可以使用英文方括号“[]”。与index_select() 不同,如果要使用英文方括号,需要给出在所有维度上的指标。在每个维度上,可以选取1个指标、多个指标(包括全部指标)。

  首先来看看如何给出某个维度上的指标。如果要选取某个指标,直接写出指标数即可。如果某一维度的大小为,则指标数可以是这样的非负整数,也可以是这样的负数。如果是非负数,指标指示从开始往结尾顺序数第几个元素(从0开始);如果是负数,指标的绝对值指示从结尾到开头逆序数第几个元素(从1开始)。还可以利用英文冒号“:”得到连续的指标:冒号之前是开始指标(含),冒号之后是结束指标(不含)。如果没有写开始指标,则默认从0开始;如果没有写结束指标,则默认到最末尾。在开始指标和结束指标之后,还可以再用一个英文冒号来指定一串不连续的指标。第2个英文冒号后表示从多少个元素中选取1个元素。代码清单3-7演示了如何选取某个维度上的指标。

代码清单3-7 选取某一维度的元素举例

     t = torch.arange(12)

     print(t[3]) # 选取1个元素,大小(),值为3

     print(t[-5]) # 选取1个元素,大小(),值为7

     print(t[3:6]) # 选取连续元素,大小(3,)

     print(t[:6]) # 选取连续元素,大小(6,)

     print(t[3:]) # 选取连续元素,大小(9,)

     print(t[-5:]) # 选取连续元素,大小(5,)

     print(t[3:6:2]) # 选取不连续元素,大小(2,)

     print(t[3::2]) # 选取不连续元素,大小(5,)

  对于有多个维度的情况,可以用英文逗号“,”分隔各维度的指标。代码清单3-8给出了从多个维度选取元素的示例。

代码清单3-8 从多个维度选取元素

     t = torch.arange(12).reshape(3, 4)

     print(t[2:,-2]) # 大小(1,),值为[10,]

     print(t[0,:]) # 大小(4,),值为[0,1,2,3]

  接下来介绍成员方法masked_select()。这个方法接受1个张量参数mask,它的大小必须和调用masked_selected() 方法的类实例相同,并且元素类型必须为torch.uint8。张量mask里面的元素非0即1,表示是否要选择对应元素。masked_select() 将用张量mask选定的那些值以一个一维张量的形式返回,一维张量的元素个数就是张量mask中1的个数。例如:

     t = torch.arange(12).reshape(3, 4)

     mask = torch.tensor([[1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 0]], dtype=torch.uint8)

     t.masked_select(mask) # 大小 (5,)

  最后,来看torch.Tensor类的成员take()。take() 函数不再考虑张量的具体大小,而只考虑张量的元素总个数。take() 函数将张量各元素按照唯一的指标进行索引,相当于对经过reshape(-1) 操作后的张量进行索引。例如:

     t = torch.arange(12).reshape(3, 4)

     indices = torch.tensor([2, 5, 6])

     t.take(indices) # 大小(3,),值为[2,5,6]


【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。