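The query_ball_point function corresponds to the grouping layer, which uses ball query to generate N' local regions. According to the paper, there are two parameters here: the number of points K in each region and the radius of the ball. The radius is the dominant one: points are searched for within a ball of the given radius, and K is only an upper limit on how many are kept. Both the radius and K are set by hand.
query_ball_point finds the points inside spherical neighborhoods. Of its inputs, radius is the radius of each ball, nsample is the number of points to sample in each neighborhood, new_xyz holds the centers of the S balls (obtained beforehand by farthest point sampling), and xyz is the full point cloud. The output is the index of the nsample sampled points for every ball of every sample in the batch, with shape [B, S, nsample].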
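def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3], S denotes the number of center points
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    # Every center starts with the full index list 0..N-1, shape [B, S, N]
    group_idx = torch.arange(N, dtype=torch.long).to(device)\
        .view(1, 1, N).repeat([B, S, 1])
    # square_distance: pairwise squared-distance helper, not shown in this section
    sqrdists = square_distance(new_xyz, xyz)
    # Points outside the ball get the out-of-range index N as a marker
    group_idx[sqrdists > radius ** 2] = N
    # Ascending sort pushes the N markers to the end; keep the first nsample
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # If fewer than nsample points are in the ball, pad with the first index
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx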
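To see the whole function at work, here is a minimal end-to-end sketch. The square_distance below is an assumed stand-in written with broadcasting (the helper itself is not shown in this post), and the point coordinates are made up for illustration.

```python
import torch

def square_distance(src, dst):
    # Assumed stand-in for the helper: [B, S, 3] x [B, N, 3] -> [B, S, N]
    diff = src[:, :, None, :] - dst[:, None, :, :]
    return (diff ** 2).sum(dim=-1)

def query_ball_point(radius, nsample, xyz, new_xyz):
    # Same logic as the function above
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device) \
        .view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx

# Hypothetical cloud: five points on the x axis, B=1, N=5
xyz = torch.tensor([[[0.0, 0.0, 0.0],
                     [1.0, 0.0, 0.0],
                     [2.0, 0.0, 0.0],
                     [3.0, 0.0, 0.0],
                     [10.0, 0.0, 0.0]]])
# Two ball centers (S=2), both taken from the cloud itself
new_xyz = xyz[:, [0, 4], :]

idx = query_ball_point(radius=1.5, nsample=3, xyz=xyz, new_xyz=new_xyz)
print(idx)
# tensor([[[0, 1, 0],
#          [4, 4, 4]]])
```

The first ball contains points 0 and 1, so the third slot is padded with index 0; the second ball contains only point 4, which is repeated three times. Because each center here is itself a point of the cloud, every ball holds at least one point and the padding always has a valid index to copy.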
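1. Understanding group_idx:
group_idx = torch.arange(N, dtype=torch.long).to(device)\
    .view(1, 1, N).repeat([B, S, 1])
N is the total number of points in one sample. torch.arange(N) produces tensor([0, 1, ..., N-1]); .to(device) moves that tensor to the device that xyz lives on; .view(1, 1, N) reshapes it to tensor([[[0, 1, ..., N-1]]]), that is, shape [1, 1, N]; and .repeat([B, S, 1]) tiles it B times along dimension 0 and S times along dimension 1 (both had size 1 before). The result has B batch entries, each with S rows and N columns, so group_idx ends up with shape [B, S, N]. The following code demonstrates this: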
import torch
N=5
B=3
S=2
group_idx0 = torch.arange(N, dtype=torch.long)
group_idx1=group_idx0.view(1, 1, N)
group_idx2=group_idx1.repeat([B, S, 1])
print("g0:",group_idx0)
print("g1:",group_idx1)
print("g2:",group_idx2)
g0: tensor([0, 1, 2, 3, 4])
g1: tensor([[[0, 1, 2, 3, 4]]])
g2: tensor([[[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]],

        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]],

        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]]])
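The next line of the function, group_idx[sqrdists > radius ** 2] = N, overwrites the index of every point that lies outside the ball with N. A small sketch of that step follows; the squared distances are hypothetical values typed in by hand rather than computed from real points.

```python
import torch

B, S, N = 1, 2, 5
radius = 1.5

# Full index list for every center, shape [B, S, N]
group_idx = torch.arange(N, dtype=torch.long).view(1, 1, N).repeat([B, S, 1])

# Hypothetical squared distances from each of the S centers to the N points
sqrdists = torch.tensor([[[0.0, 1.0, 4.0, 9.0, 100.0],
                          [100.0, 81.0, 64.0, 49.0, 0.0]]])

# radius ** 2 == 2.25, so anything farther than that is marked with N
group_idx[sqrdists > radius ** 2] = N
print(group_idx)
# tensor([[[0, 1, 5, 5, 5],
#          [5, 5, 5, 5, 4]]])
```

N works as a marker because valid indices only run from 0 to N-1, so it can never be confused with a real point.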
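2. Understanding group_idx.sort:
The signature is torch.sort(input, dim=-1, descending=False, out=None). dim=-1 means the last dimension, which for the 3-D tensor in the source is dim=2. torch.sort returns a (values, indices) pair, which is why the source takes element [0]; the printed results below show only the sorted values.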
a=torch.randn(2,3,4)
print("a",a)
print("dim=0",torch.sort(a,0))
print("dim=1",torch.sort(a,1))
print("dim=2",torch.sort(a,2))
print("dim=-1",torch.sort(a,-1))
a tensor([[[ 0.1644, -0.9524, -0.0522, -1.7683],
[-0.0426, -1.3940, -0.9358, -2.5367],
[ 0.6171, 0.2587, 1.6798, 0.3828]],
[[ 1.0571, -0.2126, -0.1489, 0.5902],
[ 0.1673, -0.5937, -0.3240, 1.1439],
[-0.4273, -0.4449, -0.8735, -0.6969]]])
dim=0 (tensor([[[ 0.1644, -0.9524, -0.1489, -1.7683],
[-0.0426, -1.3940, -0.9358, -2.5367],
[-0.4273, -0.4449, -0.8735, -0.6969]],
[[ 1.0571, -0.2126, -0.0522, 0.5902],
[ 0.1673, -0.5937, -0.3240, 1.1439],
[ 0.6171, 0.2587, 1.6798, 0.3828]]]))
dim=1 (tensor([[[-0.0426, -1.3940, -0.9358, -2.5367],
[ 0.1644, -0.9524, -0.0522, -1.7683],
[ 0.6171, 0.2587, 1.6798, 0.3828]],
[[-0.4273, -0.5937, -0.8735, -0.6969],
[ 0.1673, -0.4449, -0.3240, 0.5902],
[ 1.0571, -0.2126, -0.1489, 1.1439]]])
dim=2 (tensor([[[-1.7683, -0.9524, -0.0522, 0.1644],
[-2.5367, -1.3940, -0.9358, -0.0426],
[ 0.2587, 0.3828, 0.6171, 1.6798]],
[[-0.2126, -0.1489, 0.5902, 1.0571],
[-0.5937, -0.3240, 0.1673, 1.1439],
[-0.8735, -0.6969, -0.4449, -0.4273]]])
dim=-1 (tensor([[[-1.7683, -0.9524, -0.0522, 0.1644],
[-2.5367, -1.3940, -0.9358, -0.0426],
[ 0.2587, 0.3828, 0.6171, 1.6798]],
[[-0.2126, -0.1489, 0.5902, 1.0571],
[-0.5937, -0.3240, 0.1673, 1.1439],
[-0.8735, -0.6969, -0.4449, -0.4273]]])
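Since the function only uses element [0] of the sort result, it may help to look at both parts of the return value on a tiny tensor. This is a small illustration with made-up integers, not code from the original source.

```python
import torch

a = torch.tensor([[3, 1, 2],
                  [9, 7, 8]])

# torch.sort returns the sorted values and the positions they came from
values, indices = torch.sort(a, dim=-1)
print(values)   # tensor([[1, 2, 3], [7, 8, 9]])
print(indices)  # tensor([[1, 2, 0], [1, 2, 0]])

# Tensor.sort(...)[0] is the same values tensor
same = torch.equal(a.sort(dim=-1)[0], values)
print(same)     # True
```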
After group_idx.sort(dim=-1)[0][:, :, :nsample], group_idx has shape [B, S, nsample].
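A short sketch of the sort-and-truncate step, assuming the out-of-ball entries have already been set to N. The index values are made up for illustration.

```python
import torch

N, nsample = 5, 3

# Hypothetical group_idx after marking: entries equal to N are outside the ball
group_idx = torch.tensor([[[0, 5, 2, 5, 4],
                           [5, 5, 5, 5, 4]]])

# Ascending sort moves every N marker to the end of its row
sorted_idx, _ = group_idx.sort(dim=-1)
print(sorted_idx)
# tensor([[[0, 2, 4, 5, 5],
#          [4, 5, 5, 5, 5]]])

# Keep only the first nsample entries, shape [B, S, nsample]
truncated = sorted_idx[:, :, :nsample]
print(truncated)
# tensor([[[0, 2, 4],
#          [4, 5, 5]]])
```

Note that the sort is over index values, not distances, so when a ball holds more than nsample points the ones kept are those with the smallest indices rather than the ones nearest the center.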
3. Understanding group_idx[mask] = group_first[mask]:
import torch
N=5
B=3
S=2
group_idx0 = torch.arange(N, dtype=torch.long)
group_idx1=group_idx0.view(1, 1, N)
group_idx2=group_idx1.repeat([B, S, 1])
mask= group_idx2 == 3
print("mask:",mask)
print("group_idx2[mask]:",group_idx2[mask])
group_idx2[mask] =10
print("group_idx2:",group_idx2)
mask: tensor([[[0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0]],

        [[0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0]],

        [[0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0]]], dtype=torch.uint8)
group_idx2[mask]: tensor([3, 3, 3, 3, 3, 3])
group_idx2: tensor([[[ 0,  1,  2, 10,  4],
         [ 0,  1,  2, 10,  4]],

        [[ 0,  1,  2, 10,  4],
         [ 0,  1,  2, 10,  4]],

        [[ 0,  1,  2, 10,  4],
         [ 0,  1,  2, 10,  4]]])
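From this we can conclude: mask must be a boolean tensor with the same shape as the tensor being indexed. In PyTorch versions before 1.2 a comparison returns a ByteTensor of 0s and 1s, as in the output above; newer versions return dtype torch.bool and print True/False. The assignment writes the value into the indexed tensor at every position where mask is true. The value may be a scalar, like the 10 above, or a tensor with one element per true position, like group_first[mask] in the source.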
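Applied to the actual padding step, the same masked assignment looks like the sketch below. It assumes group_idx has already been sorted and truncated to nsample columns; the index values are made up for illustration.

```python
import torch

N, nsample = 5, 3

# Hypothetical group_idx after sort and truncation; N marks an empty slot
group_idx = torch.tensor([[[0, 2, 4],
                           [4, 5, 5]]])
B, S, _ = group_idx.shape

# First index of each row, repeated across all nsample columns
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])

# True where a slot still holds the marker N
mask = group_idx == N

# Both sides select the same positions, so the element counts match
group_idx[mask] = group_first[mask]
print(group_idx)
# tensor([[[0, 2, 4],
#          [4, 4, 4]]])
```

The second row had only one point in its ball, so the two empty slots are filled by repeating that point's index, which keeps every region at exactly nsample entries.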
Source: blog.csdn.net. Author: Studying_swz. Copyright belongs to the original author; please contact the author before reposting.
Original link: blog.csdn.net/qq_37534947/article/details/117079778