graphlearn_torch/python/distributed/dist_sampling_producer.py
def _sampling_worker_loop(rank,
                          data: DistDataset,
                          sampler_input: Union[NodeSamplerInput, EdgeSamplerInput],
                          unshuffled_index: Optional[torch.Tensor],
                          sampling_config: SamplingConfig,
                          worker_options: _BasicDistSamplingWorkerOptions,
                          channel: ChannelBase,
                          task_queue: mp.Queue,
                          sampling_completed_worker_count: mp.Value,
                          mp_barrier):
  r""" Subprocess work loop for a sampling worker.

  Initializes the worker group, RPC and the distributed sampler, then serves
  sampling commands from the task queue until a STOP command is received.
  """
  dist_sampler = None
  try:
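    # Register this subprocess in the sampling worker group under its global worker rank.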
    init_worker_group(
      world_size=worker_options.worker_world_size,
      rank=worker_options.worker_ranks[rank],
      group_name='_sampling_worker_subprocess'
    )
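    # When all-to-all communication is enabled, join a gloo process group over TCP.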
    if worker_options.use_all2all:
      torch.distributed.init_process_group(
        backend='gloo',
        timeout=datetime.timedelta(seconds=worker_options.rpc_timeout),
        rank=worker_options.worker_ranks[rank],
        world_size=worker_options.worker_world_size,
        init_method='tcp://{}:{}'.format(worker_options.master_addr, worker_options.master_port)
      )
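    # Size the RPC thread pool: default to one thread per data partition, capped at 16.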
    if worker_options.num_rpc_threads is None:
      num_rpc_threads = min(data.num_partitions, 16)
    else:
      num_rpc_threads = worker_options.num_rpc_threads
    current_device = worker_options.worker_devices[rank]
    ensure_device(current_device)
    _set_worker_signal_handlers()
    torch.set_num_threads(num_rpc_threads + 1)
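    # Join the RPC group used for cross-partition sampling and feature collection.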
    init_rpc(
      master_addr=worker_options.master_addr,
      master_port=worker_options.master_port,
      num_rpc_threads=num_rpc_threads,
      rpc_timeout=worker_options.rpc_timeout
    )
    if sampling_config.seed is not None:
      seed_everything(sampling_config.seed)
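    # Build the distributed neighbor sampler and start its internal sampling loop.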
    dist_sampler = DistNeighborSampler(
      data, sampling_config.num_neighbors, sampling_config.with_edge,
      sampling_config.with_neg, sampling_config.with_weight,
      sampling_config.edge_dir, sampling_config.collect_features, channel,
      worker_options.use_all2all, worker_options.worker_concurrency,
      current_device, seed=sampling_config.seed
    )
    dist_sampler.start_loop()
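    # Epochs that do not reshuffle seeds reuse a single pre-built index loader.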
    if unshuffled_index is not None:
      unshuffled_index_loader = DataLoader(
        unshuffled_index, batch_size=sampling_config.batch_size,
        shuffle=False, drop_last=sampling_config.drop_last
      )
    else:
      unshuffled_index_loader = None
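    # Wait until every worker has finished initialization, then serve producer commands.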
    mp_barrier.wait()
    keep_running = True
    while keep_running:
      try:
        command, args = task_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
      except queue.Empty:
        continue
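      # SAMPLE_ALL: run one epoch of sampling over the provided (or pre-built) seed index.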
      if command == MpCommand.SAMPLE_ALL:
        seeds_index = args
        if seeds_index is None:
          loader = unshuffled_index_loader
        else:
          loader = DataLoader(
            seeds_index, batch_size=sampling_config.batch_size,
            shuffle=False, drop_last=sampling_config.drop_last
          )
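        # Issue sampling requests batch by batch; results are written to the output channel.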
        if sampling_config.sampling_type == SamplingType.NODE:
          for index in loader:
            dist_sampler.sample_from_nodes(sampler_input[index])
        elif sampling_config.sampling_type == SamplingType.LINK:
          for index in loader:
            dist_sampler.sample_from_edges(sampler_input[index])
        elif sampling_config.sampling_type == SamplingType.SUBGRAPH:
          for index in loader:
            dist_sampler.subgraph(sampler_input[index])
        dist_sampler.wait_all()
        with sampling_completed_worker_count.get_lock():
          sampling_completed_worker_count.value += 1  # non-atomic, lock is necessary
      elif command == MpCommand.STOP:
        keep_running = False
      else:
        raise RuntimeError("Unknown command type")
  except KeyboardInterrupt:
    # Main process will raise KeyboardInterrupt anyway.
    pass
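  # Tear down the sampler loop and RPC regardless of how the loop exited.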
  if dist_sampler is not None:
    dist_sampler.shutdown_loop()
  shutdown_rpc(graceful=False)
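A minimal sketch of how a producer process might drive one such worker, assuming a single worker and assuming `data`, `sampler_input`, `unshuffled_index`, `sampling_config`, `worker_options` and `channel` have already been constructed; this is illustrative wiring, not the library's actual producer implementation in this file.

import torch.multiprocessing as mp

num_workers = 1
task_queue = mp.Queue()
completed = mp.Value('I', 0)           # workers increment this after SAMPLE_ALL finishes
barrier = mp.Barrier(num_workers + 1)  # producer + workers

worker = mp.Process(
  target=_sampling_worker_loop,
  args=(0, data, sampler_input, unshuffled_index, sampling_config,
        worker_options, channel, task_queue, completed, barrier)
)
worker.start()
barrier.wait()  # returns once the worker has finished its RPC/sampler setup

task_queue.put((MpCommand.SAMPLE_ALL, None))  # None -> reuse the unshuffled index loader
# ... poll until completed.value == num_workers, draining `channel` meanwhile ...
task_queue.put((MpCommand.STOP, None))        # ask the worker loop to exit
worker.join()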