def _sampling_worker_loop()

in graphlearn_torch/python/distributed/dist_sampling_producer.py


def _sampling_worker_loop(rank,
                          data: DistDataset,
                          sampler_input: Union[NodeSamplerInput, EdgeSamplerInput],
                          unshuffled_index: Optional[torch.Tensor],
                          sampling_config: SamplingConfig,
                          worker_options: _BasicDistSamplingWorkerOptions,
                          channel: ChannelBase,
                          task_queue: mp.Queue,
                          sampling_completed_worker_count: mp.Value,
                          mp_barrier):
  r""" Subprocess work loop for sampling worker.
  """
  dist_sampler = None
  try:
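    # Register this subprocess in the logical sampling-worker group used by the distributed worker context.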
    init_worker_group(
      world_size=worker_options.worker_world_size,
      rank=worker_options.worker_ranks[rank],
      group_name='_sampling_worker_subprocess'
    )
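    # When all-to-all feature collection is enabled, a gloo process group is set up
    # for collective communication between sampling workers.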
    if worker_options.use_all2all:
      torch.distributed.init_process_group(
        backend='gloo',
        timeout=datetime.timedelta(seconds=worker_options.rpc_timeout),
        rank=worker_options.worker_ranks[rank],
        world_size=worker_options.worker_world_size,
        init_method='tcp://{}:{}'.format(worker_options.master_addr, worker_options.master_port)
      )

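    # Default the RPC thread count to one thread per data partition, capped at 16.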
    if worker_options.num_rpc_threads is None:
      num_rpc_threads = min(data.num_partitions, 16)
    else:
      num_rpc_threads = worker_options.num_rpc_threads

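    # Bind this worker to its assigned device and install the standard
    # dataloader-style worker signal handlers.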
    current_device = worker_options.worker_devices[rank]
    ensure_device(current_device)

    _set_worker_signal_handlers()
    torch.set_num_threads(num_rpc_threads + 1)

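    # Bring up the RPC layer so this worker can exchange sampling and feature
    # requests with remote partitions.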
    init_rpc(
      master_addr=worker_options.master_addr,
      master_port=worker_options.master_port,
      num_rpc_threads=num_rpc_threads,
      rpc_timeout=worker_options.rpc_timeout
    )

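    # Seed all RNGs for reproducible sampling when a seed is configured.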
    if sampling_config.seed is not None:
      seed_everything(sampling_config.seed)
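    # Build the distributed neighbor sampler (results are pushed to `channel`)
    # and start its asynchronous sampling loop.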
    dist_sampler = DistNeighborSampler(
      data, sampling_config.num_neighbors, sampling_config.with_edge,
      sampling_config.with_neg, sampling_config.with_weight,
      sampling_config.edge_dir, sampling_config.collect_features, channel,
      worker_options.use_all2all, worker_options.worker_concurrency,
      current_device, seed=sampling_config.seed
    )
    dist_sampler.start_loop()

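    # Pre-build a sequential loader over the unshuffled seed index so that
    # epochs without shuffling can reuse it.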
    if unshuffled_index is not None:
      unshuffled_index_loader = DataLoader(
        unshuffled_index, batch_size=sampling_config.batch_size,
        shuffle=False, drop_last=sampling_config.drop_last
      )
    else:
      unshuffled_index_loader = None

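    # Synchronize with the parent producer: sampling starts only after every
    # worker has finished initialization.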
    mp_barrier.wait()

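    # Poll the task queue with a timeout (MP_STATUS_CHECK_INTERVAL) instead of
    # blocking indefinitely.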
    keep_running = True
    while keep_running:
      try:
        command, args = task_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
      except queue.Empty:
        continue

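      # SAMPLE_ALL: run one full epoch of sampling; args may carry a pre-shuffled seed index.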
      if command == MpCommand.SAMPLE_ALL:
        seeds_index = args
        if seeds_index is None:
          loader = unshuffled_index_loader
        else:
          loader = DataLoader(
            seeds_index, batch_size=sampling_config.batch_size,
            shuffle=False, drop_last=sampling_config.drop_last
          )

        if sampling_config.sampling_type == SamplingType.NODE:
          for index in loader:
            dist_sampler.sample_from_nodes(sampler_input[index])
        elif sampling_config.sampling_type == SamplingType.LINK:
          for index in loader:
            dist_sampler.sample_from_edges(sampler_input[index])
        elif sampling_config.sampling_type == SamplingType.SUBGRAPH:
          for index in loader:
            dist_sampler.subgraph(sampler_input[index])

        dist_sampler.wait_all()

        with sampling_completed_worker_count.get_lock():
          sampling_completed_worker_count.value += 1  # += on a shared Value is not atomic; the lock is required

      elif command == MpCommand.STOP:
        keep_running = False
      else:
        raise RuntimeError("Unknown command type")
  except KeyboardInterrupt:
    # The main process will raise KeyboardInterrupt anyway.
    pass

  if dist_sampler is not None:
    dist_sampler.shutdown_loop()
  shutdown_rpc(graceful=False)
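
Below is a minimal, hypothetical sketch of how a parent producer process could drive this loop through the shared task queue. The queue, counter, barrier, and worker construction shown here are illustrative assumptions; in graphlearn_torch the producer class in this same file owns these objects and spawns the workers itself.

import torch.multiprocessing as mp
# `MpCommand` is assumed to be importable from this module.

ctx = mp.get_context('spawn')
task_queue = ctx.Queue()
sampling_completed_worker_count = ctx.Value('i', 0)  # shared epoch-completion counter
mp_barrier = ctx.Barrier(2)                          # parent + one worker, for illustration

# A worker process running `_sampling_worker_loop` is assumed to have been
# spawned with these objects plus the dataset, sampler input, sampling config,
# worker options, and output channel.

mp_barrier.wait()                                 # block until the worker is initialized
task_queue.put((MpCommand.SAMPLE_ALL, None))      # sample one epoch over the unshuffled index
# ... drain sampled messages from the channel; `sampling_completed_worker_count.value`
#     increments once the worker has finished the epoch ...
task_queue.put((MpCommand.STOP, None))            # terminate the worker loop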