in src/gluonts/nursery/SCott/pts/model/forecast_generator.py [0:0]
def _extract_instances(x: Any) -> Any:
"""
Helper function to extract individual instances from batched
mxnet results.
For a tensor `a`
_extract_instances(a) -> [a[0], a[1], ...]
For (nested) tuples of tensors `(a, (b, c))`
_extract_instances((a, (b, c)) -> [(a[0], (b[0], c[0])), (a[1], (b[1], c[1])), ...]
"""
if isinstance(x, (np.ndarray, torch.Tensor)):
for i in range(x.shape[0]):
# yield x[i: i + 1]
yield x[i]
elif isinstance(x, tuple):
for m in zip(*[_extract_instances(y) for y in x]):
yield tuple([r for r in m])
elif isinstance(x, list):
for m in zip(*[_extract_instances(y) for y in x]):
yield [r for r in m]
elif x is None:
while True:
yield None
else:
assert False