in wypr/modeling/backbone/pointnet2/pointnet2_modules.py [0:0]
def forward(self, xyz: torch.Tensor,
            features: torch.Tensor = None,
            inds: torch.Tensor = None) -> tuple:
    r"""
    Pick centroids, group the points around them, run the shared MLP on
    every group and pool each group down to a single descriptor.

    Parameters
    ----------
    xyz : torch.Tensor
        (B, N, 3) tensor of the xyz coordinates of the features
    features : torch.Tensor
        (B, C, N) tensor of the descriptors of the features
    inds : torch.Tensor
        (B, npoint) tensor that stores index to the xyz points (values in 0-N-1).
        When omitted and ``self.npoint`` is set, the indices are obtained
        with furthest point sampling.

    Returns
    -------
    new_xyz : torch.Tensor
        (B, npoint, 3) tensor of the new features' xyz, or None when
        ``self.npoint`` is None
    new_features : torch.Tensor
        (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
    inds : torch.Tensor
        (B, npoint) tensor of the inds; None when no indices were passed in
        and ``self.npoint`` is None
    unique_cnt : torch.Tensor
        (B, npoint) tensor handed back by the grouper; only part of the
        returned tuple when ``self.ret_unique_cnt`` is set
    """
    if inds is None:
        # Sampling only makes sense with a centroid budget. With npoint=None
        # the sampler must not be called at all (it cannot take None as the
        # number of points), so inds simply stays None.
        if self.npoint is not None:
            inds = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
    else:
        assert inds.shape[1] == self.npoint, \
            "inds must provide exactly npoint indices per batch element"

    if self.npoint is not None:
        # gather_operation is fed the channel-first (B, 3, N) layout.
        xyz_flipped = xyz.transpose(1, 2).contiguous()
        new_xyz = pointnet2_utils.gather_operation(
            xyz_flipped, inds
        ).transpose(1, 2).contiguous()
    else:
        new_xyz = None

    grouped = self.grouper(xyz, new_xyz, features)
    if self.ret_unique_cnt:
        # (B, C, npoint, nsample), (B, 3, npoint, nsample), (B, npoint)
        grouped_features, grouped_xyz, unique_cnt = grouped
    else:
        # (B, C, npoint, nsample), (B, 3, npoint, nsample)
        grouped_features, grouped_xyz = grouped

    new_features = self.mlp_module(
        grouped_features
    )  # (B, mlp[-1], npoint, nsample)

    # NOTE: any pooling value other than 'max' / 'avg' / 'rbf' leaves the
    # features unpooled, exactly as before; only the trailing squeeze runs.
    if self.pooling == 'max':
        new_features = F.max_pool2d(
            new_features, kernel_size=[1, new_features.size(3)]
        )  # (B, mlp[-1], npoint, 1)
    elif self.pooling == 'avg':
        new_features = F.avg_pool2d(
            new_features, kernel_size=[1, new_features.size(3)]
        )  # (B, mlp[-1], npoint, 1)
    elif self.pooling == 'rbf':
        # Use radial basis function kernel for weighted sum of features
        # (normalized by nsample and sigma).
        # Ref: https://en.wikipedia.org/wiki/Radial_basis_function_kernel
        rbf = torch.exp(
            -1 * grouped_xyz.pow(2).sum(1, keepdim=False) / (self.sigma**2) / 2
        )  # (B, npoint, nsample)
        new_features = torch.sum(
            new_features * rbf.unsqueeze(1), -1, keepdim=True
        ) / float(self.nsample)  # (B, mlp[-1], npoint, 1)
    new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)

    if self.ret_unique_cnt:
        return new_xyz, new_features, inds, unique_cnt
    return new_xyz, new_features, inds