in graphlearn_torch/python/distributed/rpc.py [0:0]
def rpc_sync_data_partitions(num_data_partitions: int,
                             current_partition_idx: int):
  r""" Synchronize the data partition info across all workers, restricted to
  the current role group.

  Note that all data should be partitioned and used within a single role group.

  Args:
    num_data_partitions (int): The total number of data partitions.
    current_partition_idx (int): The index of the data partition that the
      current process is responsible for; some computation tasks on this data
      partition may be sent to the current process from remote workers.
  """
  ctx = get_context()
  partition2workers = [[] for _ in range(num_data_partitions)]
  gathered_results = all_gather(
    (ctx.role, num_data_partitions, current_partition_idx))
  for worker_name, (role, nparts, idx) in gathered_results.items():
    if role != ctx.role:
      raise RuntimeError(f"'rpc_sync_data_partitions': inconsistent "
                         f"role type '{role}' gathered from {worker_name}, "
                         f"current role type is '{ctx.role}'.")
    if nparts != num_data_partitions:
      raise RuntimeError(f"'rpc_sync_data_partitions': inconsistent "
                         f"data partition number '{nparts}' gathered from "
                         f"{worker_name}, the value on the current process "
                         f"is '{num_data_partitions}'.")
    partition2workers[idx].append(worker_name)
  return partition2workers
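
The mapping this function returns can be illustrated with a small, self-contained sketch of the same inversion: per-worker (role, num_partitions, partition_idx) tuples become a list indexed by partition, holding the names of the workers responsible for each partition. The worker names, role string, and partition layout below are assumptions for illustration only, not values produced by graphlearn_torch.

# Minimal, self-contained sketch (illustrative only): mirrors the inversion
# performed above, without requiring an initialized RPC context.
def _build_partition2workers(gathered, num_data_partitions):
  partition2workers = [[] for _ in range(num_data_partitions)]
  for worker_name, (_role, _nparts, idx) in gathered.items():
    partition2workers[idx].append(worker_name)
  return partition2workers

if __name__ == '__main__':
  # Hypothetical gathered results from three workers in one role group.
  gathered = {
    'server_0': ('server', 2, 0),
    'server_1': ('server', 2, 1),
    'server_2': ('server', 2, 1),  # two workers may serve the same partition
  }
  print(_build_partition2workers(gathered, 2))
  # Expected output: [['server_0'], ['server_1', 'server_2']]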