in training/utils/distributed.py [0:0]
import os
import random
import time

import torch
import torch.distributed as dist

# get_world_size, get_rank, is_main_process, and _get_global_gloo_group are
# assumed to be defined elsewhere in this module.


def all_gather_via_filesys(data, filesys_save_dir=None, gather_to_rank_0_only=False):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors), similar to
`all_gather` above, but using filesystem instead of collective ops.
If gather_to_rank_0_only is True, only rank 0 will load the gathered object list
(and other ranks will have an empty list).
"""
world_size = get_world_size()
if world_size == 1:
return [data]
print("gathering via files")
cpu_group = _get_global_gloo_group()
    # if filesys_save_dir is unspecified, fall back to EXP_DIR, then to this file's directory
if filesys_save_dir is not None:
save_dir = filesys_save_dir
elif "EXP_DIR" in os.environ:
save_dir = os.environ["EXP_DIR"]
else:
# try the same directory where the code is stored
        save_dir = os.path.dirname(__file__)
save_dir = os.path.join(save_dir, "all_gather_via_filesys")
if is_main_process():
os.makedirs(save_dir, exist_ok=True)
# use a timestamp and salt to distinguish different all_gather
timestamp = int(time.time()) if is_main_process() else 0
salt = random.randint(0, 2**31 - 1) if is_main_process() else 0
# broadcast the timestamp and salt across ranks
# (all-reduce will do the broadcasting since only rank 0 is non-zero)
timestamp_and_salt = torch.tensor([timestamp, salt], dtype=torch.long)
dist.all_reduce(timestamp_and_salt, group=cpu_group)
timestamp, salt = timestamp_and_salt.tolist()
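    # the collective above also acts as a sync point: by the time it returns,
    # rank 0 has already created save_dir, so every rank can safely write to it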
# save the data to a file on the disk
rank_save = get_rank()
save_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_save}.pkl"
save_data_path = os.path.join(save_dir, save_data_filename)
assert not os.path.exists(save_data_path), f"{save_data_path} already exists"
torch.save(data, save_data_path)
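    # wait until every rank has finished writing its file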
dist.barrier(group=cpu_group)
# read the data from the files
data_list = []
if rank_save == 0 or not gather_to_rank_0_only:
for rank_load in range(world_size):
load_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_load}.pkl"
load_data_path = os.path.join(save_dir, load_data_filename)
            assert os.path.exists(load_data_path), f"cannot read {load_data_path}"
data_list.append(torch.load(load_data_path))
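    # make sure all ranks have finished reading before any file is removed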
dist.barrier(group=cpu_group)
# delete the saved file
os.remove(save_data_path)
return data_list
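

# A minimal usage sketch (hypothetical, not part of this module): it assumes the
# default process group has already been initialized (e.g. launched via torchrun)
# and that the save directory is visible to every rank.
#
#   stats = {"rank": get_rank(), "num_samples": 1234}
#   gathered = all_gather_via_filesys(stats, gather_to_rank_0_only=True)
#   if is_main_process():
#       # non-zero ranks receive [], so only rank 0 aggregates
#       total = sum(s["num_samples"] for s in gathered)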