《神经网络与PyTorch实战》——3.4.5 比较和逻辑运算
3.4.5 比较和逻辑运算
本节介绍逐元素比较张量大小的函数和逻辑函数。
可以直接使用Python运算符(<、<=、>、>=、== 和 !=)逐元素比较两个张量的大小或张量和数值之间的大小。当运算符的左操作数和右操作数不是大小相同的张量时,会使用前文描述的广播语义。比较的结果张量的元素类型为torch.uint8。例如:
loper = torch.tensor([-1, 1, 3], dtype=torch.float32)
roper = torch.arange(3)
print ('< : {}'.format(loper < roper))
print ('== : {}'.format(loper == roper))
注意:在PyTorch中使用运算符“==”比较两个张量时,是进行逐元素比较,而不是将张量当作整体比较两个张量的所有元素是不是都相同。如果想把两个张量作为整体比较,可以用torch.equal() 函数。
另外,torch.Tensor类还有成员方法nonzero(),相当于与0进行逐元素相等比较。
还可以用torch.min() 函数和torch.max() 函数分别逐元素求两个张量的最小值和最大值,例如:
torch.max(loper, roper)
注意:对于torch.min() 函数和torch.max() 函数,当它们只传入1个张量参数时,会分别试图统计张量内所有元素的最小值和最大值,是上一节介绍的统计函数。当它们传入两个张量参数时,则分别是逐元素选取最小值和最大值。
接下来介绍张量的逻辑运算。这里只介绍torch.where() 函数。torch.where() 函数实现了逐元素if-else的功能。torch.where() 函数有3个参数:condition、x和y。当这3个参数是大小相同的张量时,该函数返回同样大小的张量。返回的张量里的元素是这样确定的:对于每个元素,考虑condition张量中对应元素的值是1(即表示真的值)还是0(即表示假的值)。如果是1,则选择张量x中对应的元素;如果是0,则选择张量y中对应的元素。例如,下列代码就根据传入的第0个参数的值,一次选择了x中的元素、y中的元素:
cond = torch.tensor([1, 0, 1], dtype=torch.uint8)
x = torch.tensor([0.3, -0.5, 0.2])
y = torch.tensor([-0.2, 0.5, 0.3])
torch.where(cond, x, y) # 得到 [0.3, 0.5, 0.2]
- 点赞
- 收藏
- 关注作者
评论(0)