def _select_bins()

in flowtorch/ops/__init__.py [0:0]


def _select_bins(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    Gather one entry along the last axis of ``x`` for every index in ``idx``.

    Indices are first clamped into the valid range ``[0, x.size(-1) - 1]``, so
    out-of-range values select the first or last bin rather than raising.

    Expected layout (by convention of the callers, not enforced here):
        idx ~ (batch_dims, input_dim, 1)
        x   ~ (context_batch_dims, input_dim, count_bins)
    where the context batch dimensions of ``x`` must broadcast over the input
    batch dimensions of ``idx``.

    Returns the gathered values with the trailing singleton axis squeezed out.
    """
    last_bin = x.size(-1) - 1
    idx = idx.clamp(min=0, max=last_bin)

    # ``gather`` needs ``x`` and ``idx`` to have the same rank, so when ``idx``
    # has at least as many dims as ``x`` we left-pad ``x`` with singleton axes
    # and then expand its leading (batch) axes to those of ``idx``. The last
    # two axes of ``x`` are kept as they are (-1 means "leave unchanged").
    missing_dims = idx.dim() - x.dim()
    if missing_dims >= 0:
        padded_shape = (1,) * missing_dims + x.shape
        x = x.reshape(padded_shape)
        x = x.expand(idx.shape[:-2] + (-1, -1))

    selected = x.gather(-1, idx)
    return selected.squeeze(-1)