in shap_e/models/nn/ops.py [0:0]
def forward(self, xyz, points):
    """
    Group the input point set and embed every group with the MLP convs.

    Input:
        xyz: input points position data, [B, C, N]
        points: input points data, [B, D, N], or None when there are no
            extra per-point features (None is forwarded to the grouping
            helper unchanged)
    Return:
        new_points: sample points feature data, [B, d_hidden[-1], n_point]
            (NOTE(review): with group_all the last axis is presumably 1 —
            confirm against sample_and_group_all)
    """
    # Switch to channels-last ([B, N, C] / [B, N, D]) for the grouping helpers.
    xyz = xyz.permute(0, 2, 1)
    if points is not None:
        points = points.permute(0, 2, 1)

    # The helpers also return the sampled centroid positions
    # ([B, n_point, C]); they are not needed here because this method only
    # returns the pooled group features, so they are discarded.
    if self.group_all:
        _, new_points = sample_and_group_all(xyz, points)
    else:
        # presumably `deterministic` makes point sampling reproducible
        # outside of training — confirm against sample_and_group
        _, new_points = sample_and_group(
            self.n_point,
            self.radius,
            self.n_sample,
            xyz,
            points,
            deterministic=not self.training,
            fps_method=self.fps_method,
        )
    # new_points: sampled points data, [B, n_point, n_sample, C+D].
    # Move channels to dim 1 so the convs see [B, C+D, n_sample, n_point].
    new_points = new_points.permute(0, 3, 2, 1)
    for conv in self.mlp_convs:
        new_points = self.act(self.apply_conv(new_points, conv))
    # Average-pool over the n_sample (neighbourhood) axis -> [B, channels, n_point].
    return new_points.mean(dim=2)