当前位置:网站首页>Pointnet++代码详解(一):farthest_point_sample函数
Pointnet++代码详解(一):farthest_point_sample函数
2022-07-17 05:10:00 【weixin_42707080】
初入Pointnet++,看相关源码感觉很费力,想着把自己学到的记下来,避免后面忘记要用到又得重新思考,本系列主要讲解Pointnet++代码,其理论部分大家可以在网上自行搜索相关资料。本系列分析的源码来自:https://github.com/yanx27/Pointnet_Pointnet2_pytorch
farthest_point_sample函数是来自于Pointnet++的FPS(Farthest Point Sampling) 最远点采样法,该方法比随机采样的优势在于它可以尽可能的覆盖空间中的所有点。
最远点采样是Set Abstraction模块中较为核心的步骤,其目的是从一个输入点云中按照所需要的点的个数npoint采样出足够多的点,并且点与点之间的距离要足够远。最后的返回结果是npoint个采样点在原始点云中的索引。
FPS的逻辑如下:
假设一共有n个点,整个点集为N = {f1, f2,…,fn}, 目标是选取n1个起始点做为下一步的中心点:
- 随机选取一个点fi为起始点,并写入起始点集 B = {fi};
- 选取剩余n-1个点计算和fi点的距离,选择最远点fj写入起始点集B={fi,fj};
- 选取剩余n-2个点计算和点集B中每个点的距离, 将最短的那个距离作为该点到点集的距离, 这样得到n-2个到点集的距离,选取最远的那个点写入起始点B = {fi, fj ,fk},同时剩下n-3个点, 如果n1=3 则到此选择完毕;
- 如果n1 > 3则重复上面步骤直到选取n1个起始点为止.
具体实现步骤如下:
- 先随机初始化一个centroids矩阵,后面用于存储npoint个采样点的索引位置,大小为B×npoint,其中B为BatchSize的个数,即B个样本;
- 利用distance矩阵记录某个样本中所有点到某一个点的距离,初始化为B×N矩阵,初值给个比较大的值,后面会迭代更新;
- 利用farthest表示当前最远的点,也是随机初始化,范围为0~N,初始化B个,对应到每个样本都随机有一个初始最远点;
- batch_indices初始化为0~(B-1)的数组;
- 直到采样点达到npoint,否则进行如下迭代:
- (1)设当前的采样点centroids为当前的最远点farthest;
- (2)取出这个中心点centroid的坐标;
- (3)求出所有点到这个farthest点的欧式距离,存在dist矩阵中;
- (4) 建立一个mask,如果dist中的元素小于distance矩阵中保存的距离值,则更新distance中的对应值,随着迭代的继续distance矩阵中的值会慢慢变小,其相当于记录着某个样本中每个点距离所有已出现的采样点的最小距离;
- (5)最后从distance矩阵取出最远的点为farthest,继续下一轮迭代.
def farthest_point_sample(xyz, npoint):
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
device = xyz.device
batchsize, ndataset, dimension = xyz.shape
#to方法Tensors和Modules可用于容易地将对象移动到不同的设备(代替以前的cpu()或cuda()方法)
# 如果他们已经在目标设备上则不会执行复制操作
centroids = torch.zeros(batchsize, npoint, dtype=torch.long).to(device)
distance = torch.ones(batchsize, ndataset).to(device) * 1e10
#randint(low, high, size, dtype)
# torch.randint(3, 5, (3,))->tensor([4, 3, 4])
farthest = torch.randint(0, ndataset, (batchsize,), dtype=torch.long).to(device)
#batch_indices=[0,1,...,batchsize-1]
batch_indices = torch.arange(batchsize, dtype=torch.long).to(device)
for i in range(npoint):
# 更新第i个最远点
centroids[:,i] = farthest
# 取出这个最远点的xyz坐标
centroid = xyz[batch_indices, farthest, :].view(batchsize, 1, 3)
# 计算点集中的所有点到这个最远点的欧式距离
#等价于torch.sum((xyz - centroid) ** 2, 2)
dist = torch.sum((xyz - centroid) ** 2, -1)
# 更新distances,记录样本中每个点距离所有已出现的采样点的最小距离
mask = dist < distance
distance[mask] = dist[mask]
# 从更新后的distances矩阵中找出距离最远的点,作为最远点用于下一轮迭代
#取出每一行的最大值构成列向量,等价于torch.max(x,2)
farthest = torch.max(distance, -1)[1]
return centroids1、xyz是点云的坐标数据,其维度为[B,N,3], B代表Batchsize,即有多少样本, N代表每个样本的总点数,3代表点云的x,y,z坐标;
npoint代表采样点数,centroids代表采样点的索引,其维度为[B, N]。
2、关于device

device = xyz.device因此,这句代码说的就是将xyz的device属性赋给device,这是为了后续操作所采用的。
3、shape

可以看出shape与size()是一样的,而dim()返回的是Tensor的维度(秩)
4、to(device)

centroids = torch.zeros(batchsize, npoint, dtype=torch.long).to(device)
distance = torch.ones(batchsize, ndataset).to(device) * 1e10to方法Tensors和Modules可用于容易地将对象移动到不同的设备(代替以前的cpu()或cuda()方法)
注意:如果数据已经在目标设备上则不会执行复制操作
5、torch.randint和torch.arange
torch.randint(low=0, high, size):size是元组,产生从low到high之间的随机整数,大小为size。

torch.arange(start, end, step) # 不包括end, step是两个点间距,start默认为0,step默认为1

#randint(low, high, size, dtype)
# torch.randint(3, 5, (3,))->tensor([4, 3, 4])
farthest = torch.randint(0, ndataset, (batchsize,), dtype=torch.long).to(device)
#batch_indices=[0,1,...,batchsize-1]
batch_indices = torch.arange(batchsize, dtype=torch.long).to(device)常用函数:https://www.jianshu.com/p/46a8ad87d238
6、
for i in range(npoint):
# 更新第i个最远点,centroids:[B,npoint],farthest是最远点的索引
centroids[:,i] = farthest
# 取出batchsize的每个样本这个最远点的xyz坐标,xyz:[B,N,3]
centroid = xyz[batch_indices, farthest, :].view(batchsize, 1, 3)
# 计算点集中的所有点到这个最远点的欧式距离
#等价于torch.sum((xyz - centroid) ** 2, 2)
dist = torch.sum((xyz - centroid) ** 2, -1)
# 更新distances,记录样本中每个点距离所有已出现的采样点的最小距离
mask = dist < distance
distance[mask] = dist[mask]
# 从更新后的distances矩阵中找出距离最远的点,作为最远点用于下一轮迭代
#torch.max(distance, -1)取出每一行的最大值构成列向量,等价于torch.max(x,2)
#torch.max(distance, -1)[1]是取列向量的索引,若torch.max(distance, -1)[0]则是取列向量
farthest = torch.max(distance, -1)[1]torch.sum(input, dim, out=None) → Tensor
- input (Tensor) – 输入张量
- dim (int) – 缩减的维度
- out (Tensor, optional) – 结果张量
import torch
x = torch.randn(4, 5)
print(x)
print(x.sum(0)) #按列求和
print(x.sum(1)) #按行求和
print(torch.sum(x)) #按列求和
print(torch.sum(x, 0))#按列求和
print(torch.sum(x, 1))#按行求和
#结果:
tensor([[ 0.2210, 1.8035, 0.7671, -0.1836, -0.2794],
[-0.7922, -1.0881, -2.0180, 1.0981, 0.2320],
[-0.4681, 0.1820, 0.0502, 0.0067, 1.3218],
[ 0.4785, 1.0799, 1.6197, 0.6642, 0.6915]])
tensor([-0.5608, 1.9773, 0.4190, 1.5854, 1.9660])
tensor([ 2.3287, -2.5682, 1.0926, 4.5338])
tensor(5.3868)
tensor([-0.5608, 1.9773, 0.4190, 1.5854, 1.9660])
tensor([ 2.3287, -2.5682, 1.0926, 4.5338])
对于三维而言,
import torch
xyz = torch.tensor([[[3,7,9],[10,5,2]],[[5,4,2],[1,6,9]]])
dist0 = torch.sum(xyz, -1)
dist1 = torch.sum(xyz, 2)
dist2 = torch.sum(xyz, 1)
dist3 = torch.sum(xyz)
print("xyz:",xyz)
print("sum-1:",dist0)
print("sum2:", dist1)
print("sum1:",dist2)
print("sum:", dist3)
结果:
xyz: tensor([[[ 3, 7, 9],
[10, 5, 2]],
[[ 5, 4, 2],
[ 1, 6, 9]]])
sum-1: tensor([[19, 17],
[11, 16]])
sum2: tensor([[19, 17],
[11, 16]])
sum1: tensor([[13, 12, 11],
[ 6, 10, 11]])
sum: tensor(63)
更多sum用法详见:https://blog.csdn.net/qq_39463274/article/details/105145029
torch.max:
对于tensorA和tensorB:
- torch.max(tensorA):返回tensor中的最大值。
- torch.max(tensorA,dim):dim表示指定的维度,返回指定维度的最大数和对应下标
- torch.max(tensorA,tensorB):比较tensorA和tensorB相对较大的元素。

若为三阶张量,则结果如下:
import torch
x= torch.tensor([[[3,7,9],[10,5,2]],[[5,4,2],[1,6,9]]])
k0=torch.max(x,0)
k1=torch.max(x,1)
k2=torch.max(x,2)
k3=torch.max(x,-1)
print("x:",x)
print("k0:",k0)
print("k1:",k1)
print("k2:",k2)
print("k-1:",k3)
结果:
x: tensor([[[ 3, 7, 9],
[10, 5, 2]],
[[ 5, 4, 2],
[ 1, 6, 9]]])
k0: (tensor([[ 5, 7, 9],
[10, 6, 9]]), tensor([[1, 0, 0],
[0, 1, 1]]))
k1: (tensor([[10, 7, 9],
[ 5, 6, 9]]), tensor([[1, 0, 0],
[0, 1, 1]]))
k2: (tensor([[ 9, 10],
[ 5, 9]]), tensor([[2, 0],
[0, 2]]))
k-1: (tensor([[ 9, 10],
[ 5, 9]]), tensor([[2, 0],
[0, 2]]))
详细请见:https://blog.csdn.net/Linux_bin/article/details/95599849
边栏推荐
- Ambari2.7.5 integration es6.4.2
- MySQL 查询当天、本周,本月、上一个月的数据
- [first launch in the whole network] will an abnormal main thread cause the JVM to exit?
- 使用Gson解析错误json数据
- Calculator of wechat applet
- replace限制文本框只能输入数字,数字和字母等的正则表达式
- Pointer array & array pointer
- Parsing bad JSON data using gson
- 【语音识别入门】基础概念与框架
- Configure tabbar and request network data requests
猜你喜欢

MySQL learning notes (5) -- join join table query, self join query, paging and sorting, sub query and nested query

Unable to determine Electron version. Please specify an Electron version

Common components of wechat applet

4.东软跨境电商数仓项目--数据采集通道搭建之用户行为数据采集通道搭建(2022.6.1-2022.6.4)

C language & bit field

电商用户行为实时分析系统(Flink1.10.1)

The future of data Lakehouse - Open

微信小程序代码的构成

10.数据仓库搭建之DWD层搭建

1. Dongsoft Cross - Border E - commerce Data Warehouse Requirement specification document
随机推荐
The future of data Lakehouse - Open
throttle/debounce应用及原理
【语音识别】kaldi安装心得
MySQL learning notes (4) - (basic crud) operate the data of tables in the database
4.东软跨境电商数仓项目--数据采集通道搭建之用户行为数据采集通道搭建(2022.6.1-2022.6.4)
微信小程序密码显示隐藏(小眼睛)
5.数据采集通道搭建之业务分析
关于Kotlin泛型遇到的问题
widerperson数据集转化为YOLOv5训练格式,并加入到crowdhuman中
Common (Consortium)
12.数据仓库搭建之ADS层搭建
Parsing bad JSON data using gson
Pointer array & array pointer
How can the thread pool be monitored to help developers quickly locate online errors?
MySQL queries the data of the current day, this week, this month and last month
8.数据仓库之ODS层搭建
5.1数据采集通道搭建之业务数据采集通道搭建
Syntax differences between PgSQL and Oracle (SQL migration records)
Macro definition of C language
Custom components of wechat applet