深度实践OpenStack:基于Python的OpenStack组件开发—3.3.2 选取部分张量元素
3.3.2 选取部分张量元素
本节介绍如何选取一个张量中的部分元素,会涉及torch.Tensor类的成员方法index_select()、masked_select() 和take(),并介绍如何利用英文方括号“[]”取元素。
首先来看index_select()。这个成员方法有两个参数:参数dim是一个int值,它表示对哪个维度进行处理;参数index是一个torch.Tensor类实例,表示选取那个维度的哪些指标。如果调用index_select() 的torch.Tensor类实例的大小为,参数dim的值为,参数index的大小为,那么选取得到的张量大小为
。代码清单3-6给出了一个使用index_select() 选择元素的例子。
代码清单3-6 index_select() 用法举例
t = torch.arange(24).reshape(2, 3, 4) # 大小 = (2, 3, 4)
index = torch.tensor([1, 2]) # 大小 (2,)
t.index_select(1, index) # 大小 (2, 2, 4)
要对张量的某个维度进行选取,除了使用成员函数index_select() 外,还可以使用英文方括号“[]”。与index_select() 不同,如果要使用英文方括号,需要给出在所有维度上的指标。在每个维度上,可以选取1个指标、多个指标(包括全部指标)。
首先来看看如何给出某个维度上的指标。如果要选取某个指标,直接写出指标数即可。如果某一维度的大小为,则指标数可以是这样的非负整数,也可以是这样的负数。如果是非负数,指标指示从开始往结尾顺序数第几个元素(从0开始);如果是负数,指标的绝对值指示从结尾到开头逆序数第几个元素(从1开始)。还可以利用英文冒号“:”得到连续的指标:冒号之前是开始指标(含),冒号之后是结束指标(不含)。如果没有写开始指标,则默认从0开始;如果没有写结束指标,则默认到最末尾。在开始指标和结束指标之后,还可以再用一个英文冒号来指定一串不连续的指标。第2个英文冒号后表示从多少个元素中选取1个元素。代码清单3-7演示了如何选取某个维度上的指标。
代码清单3-7 选取某一维度的元素举例
t = torch.arange(12)
print(t[3]) # 选取1个元素,大小(),值为3
print(t[-5]) # 选取1个元素,大小(),值为7
print(t[3:6]) # 选取连续元素,大小(3,)
print(t[:6]) # 选取连续元素,大小(6,)
print(t[3:]) # 选取连续元素,大小(9,)
print(t[-5:]) # 选取连续元素,大小(5,)
print(t[3:6:2]) # 选取不连续元素,大小(2,)
print(t[3::2]) # 选取不连续元素,大小(5,)
对于有多个维度的情况,可以用英文逗号“,”分隔各维度的指标。代码清单3-8给出了从多个维度选取元素的示例。
代码清单3-8 从多个维度选取元素
t = torch.arange(12).reshape(3, 4)
print(t[2:,-2]) # 大小(1,),值为[10,]
print(t[0,:]) # 大小(4,),值为[0,1,2,3]
接下来介绍成员方法masked_select()。这个方法接受1个张量参数mask,它的大小必须和调用masked_selected() 方法的类实例相同,并且元素类型必须为torch.uint8。张量mask里面的元素非0即1,表示是否要选择对应元素。masked_select() 将用张量mask选定的那些值以一个一维张量的形式返回,一维张量的元素个数就是张量mask中1的个数。例如:
t = torch.arange(12).reshape(3, 4)
mask = torch.tensor([[1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 0]], dtype=torch.uint8)
t.masked_select(mask) # 大小 (5,)
最后,来看torch.Tensor类的成员take()。take() 函数不再考虑张量的具体大小,而只考虑张量的元素总个数。take() 函数将张量各元素按照唯一的指标进行索引,相当于对经过reshape(-1) 操作后的张量进行索引。例如:
t = torch.arange(12).reshape(3, 4)
indices = torch.tensor([2, 5, 6])
t.take(indices) # 大小(3,),值为[2,5,6]
- 点赞
- 收藏
- 关注作者
评论(0)