def all_gather_via_filesys()

in training/utils/distributed.py [0:0]


def all_gather_via_filesys(data, filesys_save_dir=None, gather_to_rank_0_only=False):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors), similar to
    `all_gather` above, but using filesystem instead of collective ops.

    Each rank pickles its object to a shared directory via torch.save, then every
    rank (or only rank 0, if gather_to_rank_0_only is True) reads back all ranks'
    files with torch.load.

    Args:
        data: any picklable object to gather from this rank.
        filesys_save_dir: directory on a filesystem shared by all ranks; if None,
            falls back to $EXP_DIR, then to this source file's directory.
        gather_to_rank_0_only: if True, only rank 0 loads the gathered object list
            (all other ranks return an empty list).

    Returns:
        list of gathered objects, one per rank (or [] on non-zero ranks when
        gather_to_rank_0_only is True; or [data] when world_size == 1).
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    print("gathering via files")
    cpu_group = _get_global_gloo_group()

    # resolve the save directory: explicit argument > $EXP_DIR > this file's dir
    if filesys_save_dir is not None:
        save_dir = filesys_save_dir
    elif "EXP_DIR" in os.environ:
        save_dir = os.environ["EXP_DIR"]
    else:
        # fall back to the same directory where the code is stored
        save_dir = os.path.dirname(__file__)
    save_dir = os.path.join(save_dir, "all_gather_via_filesys")
    if is_main_process():
        os.makedirs(save_dir, exist_ok=True)

    # use a timestamp and salt to distinguish different all_gather
    timestamp = int(time.time()) if is_main_process() else 0
    salt = random.randint(0, 2**31 - 1) if is_main_process() else 0
    # broadcast the timestamp and salt across ranks
    # (all-reduce will do the broadcasting since only rank 0 is non-zero)
    # NOTE: this collective also acts as a sync point, so rank 0's makedirs
    # above is guaranteed to be visible before any rank writes below
    timestamp_and_salt = torch.tensor([timestamp, salt], dtype=torch.long)
    dist.all_reduce(timestamp_and_salt, group=cpu_group)
    timestamp, salt = timestamp_and_salt.tolist()

    # save this rank's data to a file on the disk
    rank_save = get_rank()
    save_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_save}.pkl"
    save_data_path = os.path.join(save_dir, save_data_filename)
    assert not os.path.exists(save_data_path), f"{save_data_path} already exists"
    torch.save(data, save_data_path)
    # wait until every rank has finished writing before anyone reads
    dist.barrier(group=cpu_group)

    # read back the data written by all ranks
    data_list = []
    if rank_save == 0 or not gather_to_rank_0_only:
        for rank_load in range(world_size):
            load_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_load}.pkl"
            load_data_path = os.path.join(save_dir, load_data_filename)
            # fix: report the file being read (was mistakenly save_data_path)
            assert os.path.exists(load_data_path), f"cannot read {load_data_path}"
            data_list.append(torch.load(load_data_path))
    # wait until every rank has finished reading before files are deleted
    dist.barrier(group=cpu_group)

    # delete the saved file (each rank removes only its own)
    os.remove(save_data_path)
    return data_list