in dpr/utils/dist_utils.py [0:0]
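# NOTE: this excerpt assumes the module-level imports and helpers defined elsewhere in
# dpr/utils/dist_utils.py: `pickle`, `torch`, and the `get_rank()`, `get_world_size()`
# and `all_reduce()` helpers used below (wrappers around the corresponding
# torch.distributed calls).
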
def all_gather_list(data, group=None, max_size=16384):
    """Gathers arbitrary picklable data from all nodes into a list.

    Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
    data. Note that *data* must be picklable.

    Args:
        data (Any): data from the local worker to be gathered on other workers
        group (optional): process group of the collective
        max_size (int, optional): maximum byte size of each worker's serialized
            payload, including the size header
    """
    # Scheme: each worker writes its length-prefixed, pickled payload into its own
    # rank-specific slice of a shared zero-initialized CUDA buffer; a single
    # all_reduce (sum) then makes every rank's slice visible to all workers.
    SIZE_STORAGE_BYTES = 4  # int32 header used to encode the payload size

    enc = pickle.dumps(data)
    enc_size = len(enc)

    if enc_size + SIZE_STORAGE_BYTES > max_size:
        raise ValueError(
            'encoded data size ({} bytes) exceeds max_size ({} bytes); increase max_size'.format(
                enc_size, max_size))
    rank = get_rank()
    world_size = get_world_size()
    buffer_size = max_size * world_size

    # Lazily allocate (and cache on the function object) a CUDA gather buffer with one
    # max_size slice per worker, plus a pinned CPU buffer for staging the local payload.
    if not hasattr(all_gather_list, '_buffer') or \
            all_gather_list._buffer.numel() < buffer_size:
        all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
        all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()

    buffer = all_gather_list._buffer
    buffer.zero_()
    cpu_buffer = all_gather_list._cpu_buffer
    assert enc_size < 256 ** SIZE_STORAGE_BYTES, \
        'Encoded object size should be less than {} bytes'.format(256 ** SIZE_STORAGE_BYTES)
    size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder='big')

    # Stage [size header | pickled payload] in the pinned CPU buffer, then copy it
    # into this worker's slice of the shared CUDA buffer.
    cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes))
    cpu_buffer[SIZE_STORAGE_BYTES: enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc))

    start = rank * max_size
    size = enc_size + SIZE_STORAGE_BYTES
    buffer[start: start + size].copy_(cpu_buffer[:size])
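
    # Because every other rank's slice is still zero locally, summing the buffers
    # across workers turns this all_reduce into an all_gather of the slices.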
    all_reduce(buffer, group=group)
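
    # Decode each rank's slice: read the big-endian size header, then unpickle that
    # many payload bytes.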
    try:
        result = []
        for i in range(world_size):
            out_buffer = buffer[i * max_size: (i + 1) * max_size]
            size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder='big')
            if size > 0:
                result.append(
                    pickle.loads(bytes(out_buffer[SIZE_STORAGE_BYTES: size + SIZE_STORAGE_BYTES].tolist())))
        return result
    except pickle.UnpicklingError:
        raise Exception(
            'Unable to unpickle data from other workers. all_gather_list requires all '
            'workers to enter the function together, so this error usually indicates '
            'that the workers have fallen out of sync somehow. Workers can fall out of '
            'sync if one of them runs out of memory, or if there are other conditions '
            'in your training script that can cause one worker to finish an epoch '
            'while other workers are still iterating over their portions of the data.'
        )
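

# Usage sketch (illustrative, not part of the original file): how all_gather_list
# might be called from a training script once torch.distributed is initialized.
# The `_example_usage` wrapper, the stats dict, and the nccl / env:// setup below
# are assumptions for the example, not DPR's actual launcher code; single-node
# training is assumed so the global rank can double as the CUDA device index.
def _example_usage():
    import torch
    import torch.distributed as dist

    dist.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(dist.get_rank())

    # Each worker contributes an arbitrary picklable object...
    local_stats = {'rank': dist.get_rank(), 'loss': 0.42, 'num_examples': 128}

    # ...and receives a list with one entry per worker, ordered by rank.
    global_stats = all_gather_list(local_stats, max_size=16384)
    total_examples = sum(s['num_examples'] for s in global_stats)
    print('rank {}: {} examples across all workers'.format(dist.get_rank(), total_examples))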