def _index_tensor()

in src/gluonts/mx/distribution/distribution.py [0:0]


def _index_tensor(x: Tensor, item: Any) -> Tensor:
    """
    Apply a NumPy-style basic index ``item`` to the tensor ``x``.

    Only the subset of basic indexing expressible with ``slice_axis`` and
    ``squeeze`` is supported:

    * integers (Python ints or NumPy integer scalars) select one position
      along an axis, and that axis is dropped from the result;
    * slices whose ``step`` is ``None`` or ``1`` restrict an axis;
    * at most one ``Ellipsis`` (``...``); index items that follow it are
      applied to the trailing axes of ``x``.

    Parameters
    ----------
    x
        Tensor to index; it must provide ``slice_axis`` and ``squeeze``.
    item
        A single index item, or a tuple of index items.

    Returns
    -------
    Tensor
        The indexed tensor.

    Raises
    ------
    RuntimeError
        If ``item`` contains an unsupported element, e.g. ``None``, an
        array, a slice with a step other than ``1``, or more than one
        ``Ellipsis``.
    """
    import numbers  # local import keeps this block self-contained

    if not isinstance(item, tuple):
        item = (item,)

    squeeze: List[int] = []
    saw_ellipsis = False

    for i, item_i in enumerate(item):
        # Items after an Ellipsis address trailing axes, so they get negative
        # axis numbers counted from the end of the index tuple.
        axis = i - len(item) if saw_ellipsis else i

        # Identity / isinstance checks (not ``==``) so that array-like index
        # objects fall through to the explicit error below instead of
        # triggering their own elementwise ``__eq__``.
        if item_i is Ellipsis:
            if saw_ellipsis:
                raise RuntimeError(f"invalid indexing item: {item}")
            saw_ellipsis = True
        elif isinstance(item_i, numbers.Integral):
            pos = int(item_i)
            # ``end=pos + 1`` would be 0 (an empty range) for -1, so the last
            # position is taken by slicing through to the end instead.
            end = pos + 1 if pos != -1 else None
            x = x.slice_axis(axis=axis, begin=pos, end=end)
            # Defer squeezing to the end so axis numbers stay valid meanwhile.
            squeeze.append(axis)
        elif isinstance(item_i, slice):
            # Explicit check rather than ``assert``: asserts are stripped
            # under ``python -O`` and the step would then be silently ignored.
            if item_i.step not in (None, 1):
                raise RuntimeError(f"invalid indexing item: {item}")
            if item_i.start is None and item_i.stop is None:
                continue  # full slice ``:`` leaves the axis untouched
            start = 0 if item_i.start is None else item_i.start
            x = x.slice_axis(axis=axis, begin=start, end=item_i.stop)
        else:
            raise RuntimeError(f"invalid indexing item: {item}")

    if squeeze:
        x = x.squeeze(axis=tuple(squeeze))
    return x