【Pytorch】torch.argmax(dim=1)用法
【摘要】 一、torch.argmax()(1)torch.argmax(input, dim=None, keepdim=False)返回指定维度最大值的序号;(2)dim给定的定义是:the demention to reduce.也就是把dim这个维度的,变成这个维度的最大值的index。 二、栗子# -*- coding: utf-8 -*-"""Created on Fri Jan 7 ...
一、torch.argmax()
(1)torch.argmax(input, dim=None, keepdim=False)
返回指定维度最大值的序号;
(2)dim
给定的定义是:the demention to reduce.也就是把dim
这个维度的,变成这个维度的最大值的index。
二、栗子
# -*- coding: utf-8 -*-
"""
Created on Fri Jan 7 15:05:09 2022
@author: 86493
"""
import torch
a=torch.tensor([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 7, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
]])
b=torch.argmax(a,dim=1)
print(a)
print(a.shape)
print(b)
(1)这个例子,tensor(2, 3, 4)
,因为是dim=1
,即将第二维度去掉,变成tensor(2, 4)
,将每一个3x4数组,变成1x4数组。
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
如上所示的3×4矩阵,==取每一列的最大值==对应的下标,a[0]中第一列的最大值的行标为1, 第二列的最大值的行标为2,第三列的最大值行标为0,第4列的最大值行标为1,所以最后输出[1, 2, 0, 1],取每一列的最大值,结果为:
tensor([[[ 1, 5, 5, 2],
[ 9, -6, 2, 8],
[-3, 7, -9, 1]],
[[-1, 7, -5, 2],
[ 9, 6, 2, 8],
[ 3, 7, 9, 1]]])
torch.Size([2, 3, 4])
tensor([[1, 2, 0, 1],
[1, 0, 2, 1]])
(1)如果改成dim=2
,即将第三维去掉,==即取每一行的最大值==对应的下标,结果为tensor(2, 3)
。
import torch
a=torch.tensor([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 7, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
]])
b=torch.argmax(a,dim=2)
print(b)
print(a.shape)
"""
tensor([[2, 0, 1],
[1, 0, 2]])
torch.Size([2, 3, 4])
"""
【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)