in src/gluonts/mx/distribution/distribution.py [0:0]
def _index_tensor(x: Tensor, item: Any) -> Tensor:
""""""
squeeze: List[int] = []
if not isinstance(item, tuple):
item = (item,)
saw_ellipsis = False
for i, item_i in enumerate(item):
axis = i - len(item) if saw_ellipsis else i
if isinstance(item_i, int):
if item_i != -1:
x = x.slice_axis(axis=axis, begin=item_i, end=item_i + 1)
else:
x = x.slice_axis(axis=axis, begin=-1, end=None)
squeeze.append(axis)
elif item_i == slice(None):
continue
elif item_i == Ellipsis:
saw_ellipsis = True
continue
elif isinstance(item_i, slice):
assert item_i.step is None
start = item_i.start if item_i.start is not None else 0
x = x.slice_axis(axis=axis, begin=start, end=item_i.stop)
else:
raise RuntimeError(f"invalid indexing item: {item}")
if len(squeeze):
x = x.squeeze(axis=tuple(squeeze))
return x