in src/losses.py [0:0]
import torch

# NOTE: convert_to_distributed_tensor / convert_to_normal_tensor are assumed to be
# defined or imported elsewhere in this module (they move a tensor onto the device
# expected by the distributed backend and back to its original device afterwards).


def gather_tensors_from_all(tensor):
    """
    Wrapper over torch.distributed.all_gather that gathers `tensor` from all
    processes. Works in both distributed and non-distributed settings: if
    torch.distributed is unavailable or uninitialized, or the world size is 1,
    the input tensor is returned as a single-element list.
    """
    if tensor.ndim == 0:
        # 0-dim tensors cannot be gathered, so add a leading dimension first.
        tensor = tensor.unsqueeze(0)

    if (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and torch.distributed.get_world_size() > 1
    ):
        # Move the tensor to the backend's preferred device (e.g. CUDA for NCCL),
        # remembering its original device.
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        gathered_tensors = [
            torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(gathered_tensors, tensor)
        # Move each gathered tensor back to the original device.
        gathered_tensors = [
            convert_to_normal_tensor(_tensor, orig_device)
            for _tensor in gathered_tensors
        ]
    else:
        gathered_tensors = [tensor]

    return gathered_tensors
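
# Illustrative usage (a minimal sketch, not part of the original module): gather the
# per-process embeddings so a loss can be computed over the full global batch. The
# `embeddings` argument and the concatenation step are assumptions for illustration.
# Note that torch.distributed.all_gather does not propagate gradients through the
# tensors gathered from other ranks; only the local copy keeps its autograd graph.
def _gather_embeddings_example(embeddings):
    # Each list entry is one process's tensor of shape [local_batch, dim];
    # concatenating along dim 0 yields the global batch [world_size * local_batch, dim].
    gathered = gather_tensors_from_all(embeddings)
    return torch.cat(gathered, dim=0)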