in wypr/modeling/backbone/pointnet2/pointnet2_utils.py [0:0]
def forward(self, xyz, new_xyz, features=None):
    # type: (QueryAndGroup, torch.Tensor, torch.Tensor, Optional[torch.Tensor]) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    r"""
    Group the points lying within ``self.radius`` of each centroid.

    Parameters
    ----------
    xyz : torch.Tensor
        xyz coordinates of the features (B, N, 3)
    new_xyz : torch.Tensor
        centroids (B, npoint, 3)
    features : torch.Tensor, optional
        Descriptors of the features (B, C, N). When None, only the
        centroid-relative coordinates are used, which requires
        ``self.use_xyz`` to be True.

    Returns
    -------
    new_features : torch.Tensor
        (B, 3 + C, npoint, nsample) when features are given and
        ``self.use_xyz`` is True; (B, C, npoint, nsample) when features are
        given and ``self.use_xyz`` is False; (B, 3, npoint, nsample) when
        features is None.
    grouped_xyz : torch.Tensor
        Only if ``self.ret_grouped_xyz``: (B, 3, npoint, nsample) neighbour
        coordinates relative to their centroid (divided by ``self.radius``
        when ``self.normalize_xyz`` is set).
    unique_cnt : torch.Tensor
        Only if ``self.ret_unique_cnt``: (B, npoint) tensor holding the
        number of distinct neighbour indices in each group. It is allocated
        with the ``torch.zeros`` default dtype and device.

    If neither extra is requested, ``new_features`` is returned on its own.
    Otherwise a tuple is returned, in the order listed above.
    """
    idx = ball_query(self.radius, self.nsample, xyz, new_xyz)

    unique_cnt = None
    # Counts are needed for resampling and/or for the optional return value.
    # Previously they were only computed under sample_uniformly, so
    # ret_unique_cnt without sample_uniformly raised UnboundLocalError.
    if self.sample_uniformly or self.ret_unique_cnt:
        unique_cnt = torch.zeros((idx.shape[0], idx.shape[1]))
        for i_batch in range(idx.shape[0]):
            for i_region in range(idx.shape[1]):
                unique_ind = torch.unique(idx[i_batch, i_region, :])
                num_unique = unique_ind.shape[0]
                unique_cnt[i_batch, i_region] = num_unique
                if self.sample_uniformly:
                    # Keep every distinct index once. Fill the remaining
                    # (nsample - num_unique) slots with indices drawn
                    # uniformly, with replacement, from that distinct set.
                    sample_ind = torch.randint(0, num_unique, (self.nsample - num_unique,), dtype=torch.long)
                    all_ind = torch.cat((unique_ind, unique_ind[sample_ind]))
                    idx[i_batch, i_region, :] = all_ind

    xyz_trans = xyz.transpose(1, 2).contiguous()
    grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
    # Express each neighbour's coordinates relative to its own centroid.
    grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
    if self.normalize_xyz:
        grouped_xyz /= self.radius

    if features is not None:
        grouped_features = grouping_operation(features, idx)
        if self.use_xyz:
            new_features = torch.cat(
                [grouped_xyz, grouped_features], dim=1
            )  # (B, C + 3, npoint, nsample)
        else:
            new_features = grouped_features
    else:
        assert (
            self.use_xyz
        ), "Cannot have not features and not use xyz as a feature!"
        new_features = grouped_xyz

    ret = [new_features]
    if self.ret_grouped_xyz:
        ret.append(grouped_xyz)
    if self.ret_unique_cnt:
        ret.append(unique_cnt)
    if len(ret) == 1:
        return ret[0]
    return tuple(ret)