in training/utils/distributed.py
def all_gather(data, force_cpu=False, force_filesys=False, filesys_save_dir=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        force_cpu: if True, gather over a gloo (CPU) process group
        force_filesys: if True, gather via the filesystem instead of torch.distributed
        filesys_save_dir: directory used by the filesystem-based gather

    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # optional filesystem-based gather paths, selected via environment variables
    if os.getenv("MDETR_FILESYS_REDUCE_RANK_0_ONLY") == "1":
        return all_gather_via_filesys(
            data, filesys_save_dir, gather_to_rank_0_only=True
        )

    if os.getenv("MDETR_FILESYS_REDUCE") == "1" or force_filesys:
        return all_gather_via_filesys(data, filesys_save_dir)

    # optionally gather over a gloo (CPU) group instead of the default backend
    cpu_group = None
    if os.getenv("MDETR_CPU_REDUCE") == "1" or force_cpu:
        cpu_group = _get_global_gloo_group()

    # serialize the object into a uint8 tensor
    buffer = io.BytesIO()
    torch.save(data, buffer)
    data_view = buffer.getbuffer()
    device = "cuda" if cpu_group is None else "cpu"
    tensor = torch.ByteTensor(data_view).to(device)

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
    size_list = [
        torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)
    ]
    if cpu_group is None:
        dist.all_gather(size_list, local_size)
    else:
        print("gathering on cpu")
        dist.all_gather(size_list, local_size, group=cpu_group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)
    assert isinstance(local_size.item(), int)
    local_size = int(local_size.item())

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
    if local_size != max_size:
        padding = torch.empty(
            size=(max_size - local_size,), dtype=torch.uint8, device=device
        )
        tensor = torch.cat((tensor, padding), dim=0)
    if cpu_group is None:
        dist.all_gather(tensor_list, tensor)
    else:
        dist.all_gather(tensor_list, tensor, group=cpu_group)

    # strip the padding and deserialize each rank's payload
    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
        buffer = io.BytesIO(tensor.cpu().numpy())
        obj = torch.load(buffer)
        data_list.append(obj)
    return data_list
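
For context, a minimal usage sketch (not part of the file), assuming torch.distributed has already been initialized and this module has been imported: gathering a small per-rank metrics dict after an evaluation step. The `local_metrics` contents and the `_example_gather_metrics` helper are illustrative assumptions; only `all_gather` comes from the code above.

def _example_gather_metrics():
    # Hypothetical per-rank statistics; any picklable object works here.
    local_metrics = {"num_samples": 128, "loss_sum": 42.0}

    # Default path: serialize, pad, and all_gather over the default process group.
    # Passing force_cpu=True would instead gather over a gloo (CPU) group.
    gathered = all_gather(local_metrics)

    # `gathered` is a list with one entry per rank, in rank order.
    total = sum(m["num_samples"] for m in gathered)
    mean_loss = sum(m["loss_sum"] for m in gathered) / total
    return mean_loss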