in distributed/rpc/ddp_rpc/main.py [0:0]
def run_worker(rank, world_size):
    r"""
    Initialize RPC for one process, run the role assigned to its rank, and
    shut RPC down.

    Roles are assigned by rank (this layout assumes ``world_size == 4``):

    * rank 2 -- "master": creates a ``RemoteModule`` wrapping
      ``torch.nn.EmbeddingBag`` on the "ps" worker, launches
      ``_run_trainer`` on every trainer via ``rpc.rpc_async`` and waits for
      all of them to finish.
    * ranks 0 and 1 -- "trainer0" / "trainer1": join a gloo process group
      (for DistributedDataParallel) and then serve RPCs from the master.
    * any other rank (3) -- "ps" (parameter server): only serves RPCs.

    Args:
        rank: Global rank of this process in the RPC group.
        world_size: Total number of RPC workers (master + ps + trainers).
    """
    # Single source of truth for trainer ranks: the master uses it to decide
    # whom to launch training on, and the trainers use it to size the DDP
    # process group. The global rank is reused as the DDP rank, so these must
    # be exactly 0..len(trainer_ranks) - 1.
    trainer_ranks = (0, 1)
    master_rank = 2
    is_trainer = rank in trainer_ranks

    # We need to use different port numbers in TCP init_method for init_rpc and
    # init_process_group to avoid port conflicts.
    rpc_backend_options = TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = "tcp://localhost:29501"

    if rank == master_rank:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        # The embedding table lives on the parameter server; trainers access
        # it through this remote handle.
        remote_emb_module = RemoteModule(
            "ps",
            torch.nn.EmbeddingBag,
            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
            kwargs={"mode": "sum"},
        )
        # Start the training loop on every trainer concurrently.
        futs = [
            rpc.rpc_async(
                "trainer{}".format(trainer_rank),
                _run_trainer,
                args=(remote_emb_module, trainer_rank),
            )
            for trainer_rank in trainer_ranks
        ]
        # Wait for all training to finish.
        for fut in futs:
            fut.wait()
    elif is_trainer:
        # Initialize process group for DistributedDataParallel on trainers.
        dist.init_process_group(
            backend="gloo",
            rank=rank,
            world_size=len(trainer_ranks),
            init_method="tcp://localhost:29500",
        )
        rpc.init_rpc(
            "trainer{}".format(rank),
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        # Trainers do no work of their own here: they just serve the
        # _run_trainer RPC issued by the master until shutdown.
    else:
        rpc.init_rpc(
            "ps",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        # The parameter server only hosts the RemoteModule and serves RPCs.

    # Block until all outstanding RPCs on every worker have finished.
    rpc.shutdown()

    # Release the DDP process group that only trainers created above.
    if is_trainer:
        dist.destroy_process_group()