in hucc/utils.py [0:0]
def dim_select(input: th.Tensor, dim: int, index: th.Tensor):
    """Select entries of ``input`` along ``dim`` using ``index``.

    Only a fixed set of shape combinations is supported. ``index`` must be an
    int64 tensor, as required by ``torch.gather``:

    * ``input`` (B, K), ``index`` (N,), ``dim == 1``:
      ``out[i] = input[i, index[i]]``, shape (N,).
    * ``input`` (A, B, C), ``index`` (N,), ``dim == 0``:
      ``out[j, :] = input[index[j], j, :]``, shape (N, C).
    * ``input`` (A, B, C), ``index`` (N,), ``dim == 1``:
      ``out[i, :] = input[i, index[i], :]``, shape (N, C).
    * ``input`` (A, B, C), ``index`` (N, M), ``dim == 1``:
      ``out[i, j, :] = input[i, index[i, j], :]``, shape (N, M, C).

    ``N`` must not exceed the size of the ``input`` dimension it is paired
    with, which is a ``torch.gather`` requirement. A negative ``dim`` is
    interpreted relative to ``input.ndim``, as elsewhere in PyTorch.

    Args:
        input: Tensor to select from.
        dim: Dimension of ``input`` that ``index`` selects along.
        index: Integer positions along ``dim``.

    Returns:
        The selected values, shaped as listed above.

    Raises:
        ValueError: If ``input.ndim``, ``index.ndim`` and ``dim`` do not
            match one of the cases above.
    """
    # TODO this is a bunch of special cases for now... figure out how to
    # generalize it?
    # Normalize PyTorch-style negative dims (e.g. -1 for the last dim) so they
    # hit the same cases as their positive equivalents.
    d = dim + input.ndim if dim < 0 else dim

    if input.ndim == 2 and index.ndim == 1 and d == 1:
        return input.gather(1, index.view(-1, 1)).squeeze(1)
    if input.ndim == 3 and index.ndim == 1 and d == 0:
        # Shape index as (1, N, C) so gather picks input[index[j], j, k].
        idx = index.view(1, -1, 1).expand(1, index.shape[0], input.shape[-1])
        return input.gather(0, idx).view(-1, input.shape[-1])
    if input.ndim == 3 and index.ndim == 1 and d == 1:
        # Shape index as (N, 1, C) so gather picks input[i, index[i], k].
        idx = index.view(-1, 1, 1).expand(index.shape[0], 1, input.shape[-1])
        return input.gather(1, idx).view(-1, input.shape[-1])
    if input.ndim == 3 and index.ndim == 2 and d == 1:
        # Broadcast each index over the trailing feature dimension.
        idx = index.unsqueeze(-1).expand(*index.shape, input.shape[-1])
        return input.gather(1, idx)
    raise ValueError(
        'Can\'t dim_select this combination of tensors '
        f'(input.shape={tuple(input.shape)}, '
        f'index.shape={tuple(index.shape)}, dim={dim})'
    )