in flowtorch/ops/__init__.py [0:0]
def _select_bins(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
"""
Performs gather to select the bin in the correct way on batched inputs
"""
idx = idx.clamp(min=0, max=x.size(-1) - 1)
# Broadcast dimensions of idx over x
# idx ~ (batch_dims, input_dim, 1)
# x ~ (context_batch_dims, input_dim, count_bins)
# Note that by convention, the context variable batch dimensions must broadcast
# over the input batch dimensions.
if len(idx.shape) >= len(x.shape):
x = x.reshape((1,) * (len(idx.shape) - len(x.shape)) + x.shape)
x = x.expand(idx.shape[:-2] + (-1,) * 2)
return x.gather(-1, idx).squeeze(-1)