PointNet++ Code Explained (VII): The PointNetSetAbstractionMsg Layer
2022-07-19 05:44:00 【weixin_ forty-two million seven hundred and seven thousand and 】
A simple way to capture multi-scale patterns is to apply grouping layers at several scales, each followed by its own PointNet that extracts the features at that scale. The features from the different scales are then concatenated to form a multi-scale feature. An SA layer built this way is called a Multi-Scale Grouping (MSG) SA layer.
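One way to summarize what an MSG SA layer computes for a sampled centroid $x_i$ is the formula below. The notation is introduced here, not taken from the original post: $m$ is the number of scales, $\mathcal{N}_k(i)$ is the set of (up to $n_k$) neighbour indices returned by the ball query of radius $r_k$, $h_k$ is the shared MLP of scale $k$, $[\,\cdot\,;\,\cdot\,]$ is channel concatenation, $\Vert$ is concatenation across scales, and the max is taken channel by channel.

$$
f_i = \mathop{\Big\Vert}_{k=1}^{m} \; \max_{j \in \mathcal{N}_k(i)} h_k\big(\,[\, f_j \,;\, x_j - x_i \,]\,\big), \qquad \dim f_i = \sum_{k=1}^{m} d_k
$$

Here $d_k$ is the width of the last layer of the $k$-th MLP (mlp_list[k][-1] in the code), which is where the sum_k{mlp[k][-1]} output size in the docstring comes from. When the layer has no input features, $f_j$ is dropped and only the relative coordinates $x_j - x_i$ are fed to $h_k$.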

Its structure is mostly the same as the ordinary SA layer, except that radius_list is now a list, for example [0.1, 0.2, 0.4], with a matching entry in nsample_list and mlp_list for each radius. A ball query is run for each radius, the point features obtained at each radius are stored in new_points_list, and at the end they are concatenated along the channel dimension. The code is as follows:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# index_points, farthest_point_sample and query_ball_point are the helper
# functions explained in the earlier parts of this series.


class PointNetSetAbstractionMsg(nn.Module):
    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        '''
        PointNet Set Abstraction (SA) module with Multi-Scale Grouping (MSG)
        (Docstring kept from the TensorFlow version; forward() below takes [B, C, N] tensors.)
        Input:
            xyz: (batch_size, ndataset, 3) TF tensor
            points: (batch_size, ndataset, channel) TF tensor
            npoint: int32 -- #points sampled in farthest point sampling
            radius_list: list of float32 -- search radius in local region
            nsample_list: list of int32 -- how many points in each local region
            in_channel: int -- number of input feature channels, not counting xyz
            mlp_list: list of list of int32 -- output size for MLP on each point
        Return:
            new_xyz: (batch_size, npoint, 3) TF tensor
            new_points: (batch_size, npoint, sum_k{mlp[k][-1]}) TF tensor
        '''
        super(PointNetSetAbstractionMsg, self).__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        # One MLP (a stack of 1x1 Conv2d + BatchNorm2d) per scale
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for i in range(len(mlp_list)):
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            # +3: the relative xyz of each neighbour is concatenated to its features
            last_channel = in_channel + 3
            for out_channel in mlp_list[i]:
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)  # [B, N, C]
        if points is not None:
            points = points.permute(0, 2, 1)  # [B, N, D]
        B, N, C = xyz.shape
        S = self.npoint
        # Sample S centroids with farthest point sampling: [B, S, C]
        new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
        new_points_list = []
        for i, radius in enumerate(self.radius_list):
            K = self.nsample_list[i]
            # Indices of up to K neighbours within `radius` of each centroid: [B, S, K]
            group_idx = query_ball_point(radius, K, xyz, new_xyz)
            grouped_xyz = index_points(xyz, group_idx)  # [B, S, K, C]
            grouped_xyz -= new_xyz.view(B, S, 1, C)  # coordinates relative to the centroid
            if points is not None:
                grouped_points = index_points(points, group_idx)  # [B, S, K, D]
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
            else:
                grouped_points = grouped_xyz
            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]
            for j in range(len(self.conv_blocks[i])):
                conv = self.conv_blocks[i][j]
                bn = self.bn_blocks[i][j]
                grouped_points = F.relu(bn(conv(grouped_points)))
            new_points = torch.max(grouped_points, 2)[0]  # max over the K neighbours: [B, D', S]
            new_points_list.append(new_points)
        new_xyz = new_xyz.permute(0, 2, 1)  # [B, C, S]
        # Concatenate the features of all scales along the channel dimension
        new_points_concat = torch.cat(new_points_list, dim=1)
        return new_xyz, new_points_concat
```
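To check the shape bookkeeping without the helpers from the earlier parts, here is a self-contained sketch of the same forward pass. It is not the original code: ball_query and gather_points are hypothetical brute-force stand-ins for query_ball_point and index_points, random sampling replaces farthest point sampling to keep it short, and each per-scale MLP is written as an nn.Sequential of Conv2d, BatchNorm2d and ReLU, which computes the same conv, BN, ReLU chain as the separate conv_blocks/bn_blocks lists above. The nsample_list and mlp_list values in the usage lines are illustrative.

```python
import torch
import torch.nn as nn


def ball_query(radius, nsample, xyz, new_xyz):
    """Hypothetical brute-force stand-in for query_ball_point.

    xyz: [B, N, 3], new_xyz: [B, S, 3] -> neighbour indices [B, S, nsample].
    Assumes nsample <= N and that every ball contains at least one point.
    """
    B, N, _ = xyz.shape
    S = new_xyz.shape[1]
    # Squared distance from every centroid to every point: [B, S, N]
    sqrdists = ((new_xyz[:, :, None, :] - xyz[:, None, :, :]) ** 2).sum(-1)
    idx = torch.arange(N, device=xyz.device).view(1, 1, N).expand(B, S, N).clone()
    idx[sqrdists > radius ** 2] = N            # N marks "outside the ball"
    idx = idx.sort(dim=-1)[0][:, :, :nsample]  # in-ball indices move to the front
    first = idx[:, :, :1].expand_as(idx)       # pad short balls with their first index
    return torch.where(idx == N, first, idx)


def gather_points(points, idx):
    """Hypothetical stand-in for index_points: [B, N, C] gathered by idx [B, ...] -> [B, ..., C]."""
    batch = torch.arange(points.shape[0], device=points.device)
    return points[batch.view(-1, *([1] * (idx.dim() - 1))), idx]


class TinyMsgSA(nn.Module):
    """Minimal MSG set abstraction sketch; random sampling replaces FPS for brevity."""

    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        super().__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.mlps = nn.ModuleList()
        for mlp in mlp_list:
            layers, last = [], in_channel + 3  # features + relative xyz
            for out in mlp:
                layers += [nn.Conv2d(last, out, 1), nn.BatchNorm2d(out), nn.ReLU()]
                last = out
            self.mlps.append(nn.Sequential(*layers))

    def forward(self, xyz, points):
        # xyz: [B, 3, N], points: [B, D, N] or None (same layout as the layer above)
        xyz = xyz.permute(0, 2, 1)
        points = points.permute(0, 2, 1) if points is not None else None
        B, N, C = xyz.shape
        S = self.npoint  # assumes S <= N
        sample_idx = torch.stack([torch.randperm(N, device=xyz.device)[:S] for _ in range(B)])
        new_xyz = gather_points(xyz, sample_idx)  # [B, S, 3]
        outs = []
        for radius, K, mlp in zip(self.radius_list, self.nsample_list, self.mlps):
            idx = ball_query(radius, K, xyz, new_xyz)  # [B, S, K]
            grouped_xyz = gather_points(xyz, idx) - new_xyz.view(B, S, 1, C)
            if points is not None:
                grouped = torch.cat([gather_points(points, idx), grouped_xyz], dim=-1)
            else:
                grouped = grouped_xyz
            grouped = mlp(grouped.permute(0, 3, 2, 1))  # [B, D', K, S]
            outs.append(torch.max(grouped, 2)[0])       # [B, D', S]
        return new_xyz.permute(0, 2, 1), torch.cat(outs, dim=1)


torch.manual_seed(0)
sa = TinyMsgSA(npoint=32, radius_list=[0.1, 0.2, 0.4], nsample_list=[8, 16, 32],
               in_channel=3, mlp_list=[[32, 32, 64], [64, 64, 128], [64, 96, 128]])
xyz = torch.rand(2, 3, 256)       # [B, 3, N] coordinates
normals = torch.randn(2, 3, 256)  # [B, D, N] per-point features, here D = 3
new_xyz, feats = sa(xyz, normals)
print(new_xyz.shape, feats.shape)  # torch.Size([2, 3, 32]) torch.Size([2, 320, 32])
```

With three scales whose last MLP widths are 64, 128 and 128, the concatenated feature has 64 + 128 + 128 = 320 channels, which is the sum_k{mlp[k][-1]} from the docstring. Because every centroid is itself one of the input points, each ball contains at least that point, so the padding step always has a valid index to copy.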