in seamseg/utils/parallel/scatter_gather.py [0:0]
def gather(outputs, target_device, dim=0):
r"""
Gathers tensors from different GPUs on a specified device
(-1 means the CPU).
"""
def gather_map(outputs):
out = outputs[0]
if isinstance(out, torch.Tensor):
return Gather.apply(target_device, dim, *outputs)
if out is None:
return None
if isinstance(out, dict):
if not all((len(out) == len(d) for d in outputs)):
raise ValueError('All dicts must have the same number of keys')
return type(out)(((k, gather_map([d[k] for d in outputs]))
for k in out))
if isinstance(out, PackedSequence):
return packed_sequence_gather(outputs, target_device)
return type(out)(map(gather_map, zip(*outputs)))
# Recursive function calls like this create reference cycles.
# Setting the function to None clears the refcycle.
try:
return gather_map(outputs)
finally:
gather_map = None