Pointnet++ code explanation (III): the query_ball_point function
2022-07-19 05:44:00 【weixin_ forty-two million seven hundred and seven thousand and 】
The query_ball_point function corresponds to the Grouping layer, which uses the ball query method to generate N' local regions. Following the paper, this layer has two parameters: the number of points in each region, K, and the radius of the ball. The radius is the dominant one. The layer collects the points inside a ball of that radius, and K is only an upper limit on how many are kept. Both the ball radius and the number of points per region are specified in advance.
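The "radius first, capped at K" rule can be made concrete with a minimal loop-based sketch of ball query for a single center, written in plain Python. The name ball_query_single is hypothetical. Padding a short group with its first index is an assumption chosen to match the padding step in query_ball_point. This is an illustration, not code from the PointNet++ repository.

```python
def ball_query_single(center, points, radius, k):
    """Return up to k indices of points within `radius` of `center` (hypothetical helper).

    Indices are collected in ascending order. If fewer than k points fall
    inside the ball, the list is padded with the first index found.
    """
    r2 = radius * radius
    inside = []
    for i, p in enumerate(points):
        # compare squared distances, so no square root is needed
        d2 = sum((a - b) ** 2 for a, b in zip(center, p))
        if d2 <= r2:
            inside.append(i)
        if len(inside) == k:
            break  # the radius decides membership; k is only an upper limit
    if not inside:
        return []  # assumption for this sketch: an empty ball yields no indices
    return inside + [inside[0]] * (k - len(inside))


points = [(0.0, 0.0, 0.0), (0.1, 0.0, 0.0), (2.0, 0.0, 0.0), (0.0, 0.2, 0.0)]
print(ball_query_single((0.0, 0.0, 0.0), points, radius=0.5, k=2))  # [0, 1]
print(ball_query_single((0.0, 0.0, 0.0), points, radius=0.5, k=5))  # [0, 1, 3, 0, 0]
```

Point 2 never enters the group, even when k leaves room for it, because it lies outside the ball.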
query_ball_point finds the points that fall inside each spherical neighborhood. It takes four inputs:

- radius: the radius of the sphere.
- nsample: the maximum number of points to sample in each neighborhood.
- new_xyz: the S centers of the spherical neighborhoods, obtained from the preceding farthest point sampling step.
- xyz: the whole point cloud.

For every sample in the batch, it returns the indices of nsample points in each of the S neighborhoods, with shape [B, S, nsample].
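query_ball_point calls a helper square_distance that is not shown in this post. The sketch below assumes that helper returns squared Euclidean distances of shape [B, S, N], as its name and the comparison against radius ** 2 suggest. It shows one common way to write such a helper, using the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b so that all pairs are computed with a single batched matrix multiply. It may differ from the helper in the original code base.

```python
import torch


def square_distance(src, dst):
    """Squared Euclidean distance between every pair of points (assumed helper).

    src: [B, N, C], dst: [B, M, C] -> returns [B, N, M]
    """
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))  # -2 * s.d, shape [B, N, M]
    dist += torch.sum(src ** 2, dim=-1).unsqueeze(-1)     # + ||s||^2, broadcast from [B, N, 1]
    dist += torch.sum(dst ** 2, dim=-1).unsqueeze(1)      # + ||d||^2, broadcast from [B, 1, M]
    return dist


src = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]])                   # [1, 2, 3]
dst = torch.tensor([[[0.0, 0.0, 0.0], [0.0, 2.0, 0.0], [1.0, 1.0, 1.0]]])  # [1, 3, 3]
print(square_distance(src, dst))  # values: [[[0, 4, 3], [1, 5, 2]]]
```

Skipping the square root is harmless here, because query_ball_point compares against radius ** 2 rather than radius.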
import torch

def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3], S denotes the number of center points
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device) \
        .view(1, 1, N).repeat([B, S, 1])
    # sqrdists: [B, S, N], the squared Euclidean distance between each center point and every point
    # (square_distance is a helper defined elsewhere in the same code base and is not shown here)
    sqrdists = square_distance(new_xyz, xyz)
    # Every point whose squared distance exceeds radius^2 lies outside the ball, so its group_idx is set to N;
    # all other entries keep their original index
    group_idx[sqrdists > radius ** 2] = N
    # Sort in ascending order: the entries set to N are the largest, so they move to the end,
    # and the first nsample entries are in-ball points (those with the smallest indices) ahead of any N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # If a ball contains fewer than nsample points, some of these nsample entries are still N.
    # Those entries must be discarded, so they are replaced with the first point of the group.
    # group_first: [B, S, nsample], the first value of group_idx copied along the last dimension
    # to make the replacement easy. group_idx[:, :, 0] is a 2-D tensor [B, S], so view is needed
    # to turn it back into a 3-D tensor [B, S, 1] before repeating
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    # Find the entries of group_idx equal to N; the result is a 0/1 (boolean) tensor of shape [B, S, nsample]
    mask = group_idx == N
    # Replace those entries with the index of the first point
    group_idx[mask] = group_first[mask]
    return group_idx

1. Understanding group_idx:
group_idx = torch.arange(N, dtype=torch.long).to(device) \
    .view(1, 1, N).repeat([B, S, 1])
N is the total number of points in one sample. The expression builds the index tensor in four steps:

- torch.arange(N) generates tensor([0, 1, ..., N-1]).
- .to(device) copies that tensor to the same device as xyz.
- .view(1, 1, N) reshapes it to tensor([[[0, 1, ..., N-1]]]), a tensor of shape [1, 1, N].
- .repeat([B, S, 1]) copies it B times along dimension 0 (where there was only one copy) and S times along dimension 1.

In other words, there are B batches and each sample holds S rows of N columns, so group_idx ends up with shape [B, S, N]. In code:
import torch
N = 5
B = 3
S = 2
group_idx0 = torch.arange(N, dtype=torch.long)
group_idx1 = group_idx0.view(1, 1, N)
group_idx2 = group_idx1.repeat([B, S, 1])
print("g0:", group_idx0)
print("g1:", group_idx1)
print("g2:", group_idx2)
# result:
g0: tensor([0, 1, 2, 3, 4])
g1: tensor([[[0, 1, 2, 3, 4]]])
g2: tensor([[[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]],
        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]],
        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]]])
2. Understanding group_idx.sort:
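torch.sort(input, dim=-1, descending=False, out=None) sorts along dimension dim. dim=-1 means the last dimension, which for the 3-D group_idx in the source code is dim=2. torch.sort returns a tuple (values, indices), so the [0] in group_idx.sort(dim=-1)[0] picks out the sorted values. The example sorts a random tensor of shape [2, 3, 4] along each dimension: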

import torch
a = torch.randn(2, 3, 4)
print("a",a)
print("dim=0",torch.sort(a,0))
print("dim=1",torch.sort(a,1))
print("dim=2",torch.sort(a,2))
print("dim=-1",torch.sort(a,-1))
# result (only the sorted values are shown for each call; the indices returned alongside them are omitted)
a tensor([[[ 0.1644, -0.9524, -0.0522, -1.7683],
[-0.0426, -1.3940, -0.9358, -2.5367],
[ 0.6171, 0.2587, 1.6798, 0.3828]],
[[ 1.0571, -0.2126, -0.1489, 0.5902],
[ 0.1673, -0.5937, -0.3240, 1.1439],
[-0.4273, -0.4449, -0.8735, -0.6969]]])
dim=0 (tensor([[[ 0.1644, -0.9524, -0.1489, -1.7683],
[-0.0426, -1.3940, -0.9358, -2.5367],
[-0.4273, -0.4449, -0.8735, -0.6969]],
[[ 1.0571, -0.2126, -0.0522, 0.5902],
[ 0.1673, -0.5937, -0.3240, 1.1439],
[ 0.6171, 0.2587, 1.6798, 0.3828]]]))
dim=1 (tensor([[[-0.0426, -1.3940, -0.9358, -2.5367],
[ 0.1644, -0.9524, -0.0522, -1.7683],
[ 0.6171, 0.2587, 1.6798, 0.3828]],
[[-0.4273, -0.5937, -0.8735, -0.6969],
[ 0.1673, -0.4449, -0.3240, 0.5902],
[ 1.0571, -0.2126, -0.1489, 1.1439]]])
dim=2 (tensor([[[-1.7683, -0.9524, -0.0522, 0.1644],
[-2.5367, -1.3940, -0.9358, -0.0426],
[ 0.2587, 0.3828, 0.6171, 1.6798]],
[[-0.2126, -0.1489, 0.5902, 1.0571],
[-0.5937, -0.3240, 0.1673, 1.1439],
[-0.8735, -0.6969, -0.4449, -0.4273]]])
dim=-1 (tensor([[[-1.7683, -0.9524, -0.0522, 0.1644],
[-2.5367, -1.3940, -0.9358, -0.0426],
[ 0.2587, 0.3828, 0.6171, 1.6798]],
[[-0.2126, -0.1489, 0.5902, 1.0571],
[-0.5937, -0.3240, 0.1673, 1.1439],
[-0.8735, -0.6969, -0.4449, -0.4273]]]))
After group_idx.sort(dim=-1)[0][:, :, :nsample], group_idx has shape [B, S, nsample].
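A tiny example shows how the threshold and the sort work together, and why the sentinel N ends up at the back. The squared distances are made-up values for one batch (B = 1) and one center (S = 1), with N = 6, radius = 1.0 and nsample = 4. The sketch unpacks the (values, indices) pair explicitly instead of writing [0].

```python
import torch

N, nsample, radius = 6, 4, 1.0
# hypothetical squared distances from one center to N = 6 points, shape [B=1, S=1, N]
sqrdists = torch.tensor([[[0.0, 2.5, 0.3, 4.0, 0.9, 1.6]]])

group_idx = torch.arange(N, dtype=torch.long).view(1, 1, N)  # [[[0, 1, 2, 3, 4, 5]]]
group_idx[sqrdists > radius ** 2] = N    # points outside the ball get the sentinel index N
print(group_idx)                          # values: [[[0, 6, 2, 6, 4, 6]]]

values, indices = group_idx.sort(dim=-1)  # torch.sort returns (values, indices)
group_idx = values[:, :, :nsample]        # same as group_idx.sort(dim=-1)[0][:, :, :nsample]
print(group_idx)                          # values: [[[0, 2, 4, 6]]], only 3 points were in the ball
```

The trailing 6 is exactly the case that the group_first replacement in query_ball_point handles.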
3. Understanding group_idx[mask] = group_first[mask]:
import torch
N = 5
B = 3
S = 2
group_idx0 = torch.arange(N, dtype=torch.long)
group_idx1 = group_idx0.view(1, 1, N)
group_idx2 = group_idx1.repeat([B, S, 1])
mask = group_idx2 == 3
print("mask:", mask)
print("group_idx2[mask]:", group_idx2[mask])
group_idx2[mask] = 10
print("group_idx2:", group_idx2)
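# result (from an older PyTorch version, hence dtype=torch.uint8):
mask: tensor([[[0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0]],
        [[0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0]],
        [[0, 0, 0, 1, 0],
         [0, 0, 0, 1, 0]]], dtype=torch.uint8)
group_idx2[mask]: tensor([3, 3, 3, 3, 3, 3])
group_idx2: tensor([[[ 0,  1,  2, 10,  4],
         [ 0,  1,  2, 10,  4]],
        [[ 0,  1,  2, 10,  4],
         [ 0,  1,  2, 10,  4]],
        [[ 0,  1,  2, 10,  4],
         [ 0,  1,  2, 10,  4]]])
Several conclusions follow from this example:

- mask is a boolean tensor, here with the same shape as the tensor it indexes. In older PyTorch versions it is a ByteTensor of 0s and 1s, as above. Since PyTorch 1.2, comparison operators return torch.bool, so it prints True/False.
- group_idx2[mask] selects the elements at the positions where mask is 1 (True).
- Assigning to group_idx2[mask] replaces exactly those elements with the given value.
- The value can be a scalar, such as 10 here, or a tensor with one element per True position, such as group_first[mask] in query_ball_point.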
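The same rules explain the last step of query_ball_point. The sketch below runs that replacement on a hand-made group_idx of shape [B=1, S=2, nsample=4]. The index values are made up for illustration, with N = 6 as the sentinel. Here the right-hand side group_first[mask] is a tensor with one element per True position in mask.

```python
import torch

B, S, nsample, N = 1, 2, 4, 6
# hypothetical sorted-and-truncated indices; 6 (== N) marks "no point found"
group_idx = torch.tensor([[[0, 2, 4, 6],
                           [1, 6, 6, 6]]])

# first index of each group, repeated along the last dimension: [B, S, nsample]
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N                 # boolean tensor, same shape as group_idx
print(mask.dtype)                     # torch.bool (PyTorch 1.2 and later)
group_idx[mask] = group_first[mask]   # 1-D tensor with 4 elements, one per True entry
print(group_idx)                      # values: [[[0, 2, 4, 0], [1, 1, 1, 1]]]
```

A group with no in-ball points at all would keep N in every slot, since its first entry is N as well. In PointNet++ this does not arise when the centers in new_xyz are themselves points of xyz, because each center lies inside its own ball.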