in shap_e/models/transmitter/channels_encoder.py [0:0]
def sample_pcl_fps(points: torch.Tensor, data_ctx: int, method: str = "fps") -> torch.Tensor:
    """
    Run farthest-point sampling on a batch of point clouds.

    :param points: batch of shape [N x num_points x ...]; dim 1 indexes points.
    :param data_ctx: subsample count.
    :param method: either 'fps' or 'first'. Using 'first' assumes that the
                   points are already sorted according to FPS sampling.
    :return: batch of shape [N x min(num_points, data_ctx) x ...], on the same
             device as `points`.
    :raises ValueError: if `method` is not 'fps' or 'first'.
    """
    # Validate up front so a bad method is reported regardless of whether the
    # input happens to already have few enough points.
    if method not in ("fps", "first"):
        raise ValueError(f"unsupported farthest-point sampling method: {method}")

    n_points = points.shape[1]
    # Nothing to subsample: returning the input unchanged honours the
    # min(num_points, data_ctx) contract for both methods and never asks the
    # FPS routine for more samples than there are points.
    if n_points <= data_ctx:
        return points

    if method == "first":
        return points[:, :data_ctx]

    # FPS is run one example at a time on CPU; the stacked result is moved
    # back to the caller's device.
    # NOTE(review): assumes sample_fps maps a [1 x num_points x ...] slice to
    # [1 x data_ctx x ...] — confirm against its definition.
    batch = points.cpu().split(1, dim=0)
    fps = [sample_fps(x, n_samples=data_ctx) for x in batch]
    return torch.cat(fps, dim=0).to(points.device)