《神经网络与PyTorch实战》——3.4.5 比较和逻辑运算

举报
华章计算机 发表于 2019/06/05 20:16:29 2019/06/05
【摘要】 本书摘自《神经网络与PyTorch实战》——书中第3章,第3.4.5节,作者是肖智清。

3.4.5 比较和逻辑运算

  本节介绍逐元素比较张量大小的函数和逻辑函数。

  可以直接使用Python运算符(<、<=、>、>=、== 和 !=)逐元素比较两个张量的大小或张量和数值之间的大小。当运算符的左操作数和右操作数不是大小相同的张量时,会使用前文描述的广播语义。比较的结果张量的元素类型为torch.uint8。例如:

     loper = torch.tensor([-1, 1, 3], dtype=torch.float32)

     roper = torch.arange(3)

     print ('< : {}'.format(loper < roper))

     print ('== : {}'.format(loper == roper))

注意:在PyTorch中使用运算符“==”比较两个张量时,是进行逐元素比较,而不是将张量当作整体比较两个张量的所有元素是不是都相同。如果想把两个张量作为整体比较,可以用torch.equal() 函数。

  另外,torch.Tensor类还有成员方法nonzero(),相当于与0进行逐元素相等比较。

  还可以用torch.min() 函数和torch.max() 函数分别逐元素求两个张量的最小值和最大值,例如:

     torch.max(loper, roper)

注意:对于torch.min() 函数和torch.max() 函数,当它们只传入1个张量参数时,会分别试图统计张量内所有元素的最小值和最大值,是上一节介绍的统计函数。当它们传入两个张量参数时,则分别是逐元素选取最小值和最大值。

  接下来介绍张量的逻辑运算。这里只介绍torch.where() 函数。torch.where() 函数实现了逐元素if-else的功能。torch.where() 函数有3个参数:condition、x和y。当这3个参数是大小相同的张量时,该函数返回同样大小的张量。返回的张量里的元素是这样确定的:对于每个元素,考虑condition张量中对应元素的值是1(即表示真的值)还是0(即表示假的值)。如果是1,则选择张量x中对应的元素;如果是0,则选择张量y中对应的元素。例如,下列代码就根据传入的第0个参数的值,一次选择了x中的元素、y中的元素:

     cond = torch.tensor([1, 0, 1], dtype=torch.uint8)

     x = torch.tensor([0.3, -0.5, 0.2])

     y = torch.tensor([-0.2, 0.5, 0.3])

     torch.where(cond, x, y) # 得到 [0.3, 0.5, 0.2]


【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。