def dim_select()

in hucc/utils.py [0:0]


def dim_select(input: th.Tensor, dim: int, index: th.Tensor):
    """Select entries of ``input`` along ``dim`` at the positions in ``index``.

    Only the following rank/``dim`` combinations are supported (``dim`` may be
    negative and is normalized against ``input.ndim``):

    * ``input`` (B, N),    ``index`` (B,),   ``dim=1`` -> (B,)
      with ``out[b] = input[b, index[b]]``
    * ``input`` (N, B, D), ``index`` (B,),   ``dim=0`` -> (B, D)
      with ``out[b, d] = input[index[b], b, d]``
    * ``input`` (B, N, D), ``index`` (B,),   ``dim=1`` -> (B, D)
      with ``out[b, d] = input[b, index[b], d]``
    * ``input`` (B, N, D), ``index`` (B, K), ``dim=1`` -> (B, K, D)
      with ``out[b, k, d] = input[b, index[b, k], d]``

    ``index`` must be an integer (long) tensor, as required by
    ``Tensor.gather``.

    Raises:
        ValueError: for any other combination of ranks and ``dim``.
    """
    # TODO this is a bunch of special cases for now... figure out how to
    # generalize it?
    if dim < 0:
        # Accept negative dims like the rest of PyTorch; values that are still
        # out of range stay negative and fall through to the ValueError below.
        dim += input.ndim
    if input.ndim == 2 and index.ndim == 1 and dim == 1:
        return input.gather(1, index.view(-1, 1)).squeeze(1)
    elif input.ndim == 3 and index.ndim == 1 and dim == 0:
        # (B,) -> (1, B, D): gather picks input[index[b], b, :] for every b.
        index = index.view(1, -1, 1).expand(1, index.shape[0], input.shape[-1])
        # squeeze rather than view(-1, D): view(-1, 0) is ambiguous and raises.
        return input.gather(0, index).squeeze(0)
    elif input.ndim == 3 and index.ndim == 1 and dim == 1:
        # (B,) -> (B, 1, D): gather picks input[b, index[b], :] for every b.
        index = index.view(-1, 1, 1).expand(index.shape[0], 1, input.shape[-1])
        return input.gather(1, index).squeeze(1)
    elif input.ndim == 3 and index.ndim == 2 and dim == 1:
        # (B, K) -> (B, K, D): broadcast each index over the feature axis.
        index = index.unsqueeze(-1).expand(*index.shape, input.shape[-1])
        return input.gather(1, index)
    else:
        raise ValueError('Can\'t dim_select this combination of tensors')