def sample_pcl_fps()

in shap_e/models/transmitter/channels_encoder.py [0:0]


def sample_pcl_fps(points: torch.Tensor, data_ctx: int, method: str = "fps") -> torch.Tensor:
    """
    Run farthest-point sampling on a batch of point clouds.

    :param points: batch of point clouds with the point axis at dim 1, i.e.
                   [N x num_points x ...]. For method 'fps' the per-point
                   layout must be whatever `sample_fps` expects.
    :param data_ctx: subsample count.
    :param method: either 'fps' or 'first'. Using 'first' assumes that the
                   points are already sorted according to FPS sampling.
    :return: batch of shape [N x min(num_points, data_ctx) x ...]. When
             num_points <= data_ctx the input tensor is returned unchanged.
    :raises ValueError: if `method` is not 'fps' or 'first'.
    """
    # Validate up front so a bad `method` is reported regardless of input size.
    if method not in ("fps", "first"):
        raise ValueError(f"unsupported farthest-point sampling method: {method}")

    n_points = points.shape[1]
    # Nothing to subsample. Returning early honors the min(num_points, data_ctx)
    # contract and avoids asking sample_fps for more points than exist.
    if n_points <= data_ctx:
        return points
    if method == "first":
        return points[:, :data_ctx]

    # method == "fps"
    if points.shape[0] == 0:
        # torch.cat of an empty list raises; an empty batch only needs its
        # point axis truncated to produce the correctly shaped result.
        return points[:, :data_ctx]
    # The batch is moved to CPU and passed to sample_fps one cloud at a time,
    # then reassembled and moved back to the caller's device.
    batch = points.cpu().split(1, dim=0)
    fps = [sample_fps(x, n_samples=data_ctx) for x in batch]
    return torch.cat(fps, dim=0).to(points.device)