pytorch多维筛选
【摘要】 多级筛选:
比如结构是2*2*3,只想选第三维的最大的
tx[index, best_n, g_y_center, g_x_center]
index=[01],best_n=[0,1]
最后只取两个值,第一行,第1列,第二行,第2列的。
筛选第3维最大的值,下面的代码不对,解决方法:查询max源码
也可以把3维用view降到2维再...
多级筛选:
比如结构是2*2*3,只想选第三维的最大的
tx[index, best_n, g_y_center, g_x_center]
index=[01],best_n=[0,1]
最后只取两个值,第一行,第1列,第二行,第2列的。
筛选第3维最大的值,下面的代码不对,解决方法:查询max源码
也可以把3维用view降到2维再计算就可以了。
import torch
anch_ious = torch.Tensor([[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]]])
print('b shape',anch_ious.shape)
b = torch.max(anch_ious, 2)
print(b[0])
print(b[1])
b = b[1].squeeze(1)
print(b)
print(anch_ious[list(range(anch_ious.size(0))),list(range(anch_ious.size(1))), b])
通过值筛选:
import torch
x = torch.linspace(1, 8, steps=8).view(4, 2)
#筛选第一维和第二维都>5.5的
print(x)
area=(x[:,0]>5.5)&(x[:,1]>5.5)
b=x[area]
# b= x[torch.where((x[:,0]&g
文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。
原文链接:blog.csdn.net/jacke121/article/details/86623589
【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)