def rpc_sync_data_partitions()

in graphlearn_torch/python/distributed/rpc.py [0:0]


def rpc_sync_data_partitions(num_data_partitions: int,
                             current_partition_idx: int):
  r""" Synchronize the data partition info across all workers only in the
  current role group.

  Note that all data should be partitioned and used within a single role group.

  Args:
    num_data_partitions (int): The number of all data partitions.
    current_partition_idx (int): The index of the data partition that the
      current process is responsible for; some computation tasks on this data
      partition may be sent to the current process from remote workers.
  """
  ctx = get_context()
  partition2workers = [[] for _ in range(num_data_partitions)]
  gathered_results = all_gather(
    (ctx.role, num_data_partitions, current_partition_idx))
  for worker_name, (role, nparts, idx) in gathered_results.items():
    if role != ctx.role:
      raise RuntimeError(f"'rpc_sync_data_partition_mapping': inconsistent "
                         f"role type '{role}' gathered from {worker_name}, "
                         f"current role type is '{ctx.role}'.")
    if nparts != num_data_partitions:
      raise RuntimeError(f"'rpc_sync_data_partition_mapping': inconsistent "
                         f"data partition number '{nparts}' gathered from "
                         f"{worker_name}, the value on current process is "
                         f"'{ctx.role}'.")
    partition2workers[idx].append(worker_name)
  return partition2workers
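
Below is a minimal usage sketch, not part of the source file. It assumes the distributed RPC context has already been set up via graphlearn_torch's setup helpers (init_worker_group and init_rpc are shown as assumptions); the import paths, helper arguments, and group name are illustrative and may differ from the actual API.

# Illustrative sketch: every worker in the same role group calls
# rpc_sync_data_partitions with the same partition count once the RPC
# context is ready. Setup helpers and their arguments are assumptions.
from graphlearn_torch.distributed import init_worker_group, init_rpc
from graphlearn_torch.distributed.rpc import rpc_sync_data_partitions

# Worker 0 of a 2-worker role group, responsible for data partition 0.
init_worker_group(world_size=2, rank=0, group_name='remote_sampling')
init_rpc(master_addr='localhost', master_port=11234)

partition2workers = rpc_sync_data_partitions(
  num_data_partitions=2,
  current_partition_idx=0,
)
# partition2workers[i] now lists the worker names that own partition i,
# gathered from every worker in the 'remote_sampling' role group.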