in torchbiggraph/parameter_sharing.py [0:0]
def start(self, groups: List["td.ProcessGroup"]) -> None:
    """Serve client commands until every client has sent a join command.

    On entry, ``self.groups`` is set to the entries of ``groups`` selected
    by ``self.group_idxs``, or to ``None`` when ``self.group_idxs`` is
    ``None``.

    Each loop iteration receives one 6-element ``torch.long`` command
    tensor (pre-filled with -1) and dispatches on its first element:

    * ``STORE_CMD``: receive the key, then call ``handle_store`` with
      elements 2-5 of the command tensor.
    * ``GET_CMD``: receive the key, then call ``handle_get`` with
      element 2 of the command tensor.
    * ``SWAP_CMD``: receive the key, call ``handle_store`` as for
      ``STORE_CMD``, then ``handle_get`` with ``False`` as last argument.
    * ``JOIN_CMD``: count the join; once ``self.num_clients`` joins have
      arrived, send an ack to ranks ``0 .. num_clients - 1``, optionally
      run a barrier on ``self.groups[0]``, and return.

    Element 1 of the command tensor is forwarded to ``_recv_key`` for the
    key-based commands and is read as the barrier flag for ``JOIN_CMD``.

    Args:
        groups: Process groups, indexed by ``self.group_idxs``.

    Raises:
        RuntimeError: If the received command code matches none of the
            known commands.
    """
    if self.group_idxs is None:
        self.groups = None
    else:
        self.groups = [groups[idx] for idx in self.group_idxs]

    join_count = 0
    while True:
        # 1. receive the command
        cmd_buffer = torch.full((6,), -1, dtype=torch.long)
        rank = td.recv(cmd_buffer)
        # Decode all six slots into plain Python ints in one go.
        cmd, head_arg, *payload = cmd_buffer.tolist()

        if cmd == STORE_CMD:
            key = self._recv_key(rank, head_arg)
            self.handle_store(rank, key, *payload)
        elif cmd == GET_CMD:
            key = self._recv_key(rank, head_arg)
            self.handle_get(rank, key, payload[0])
        elif cmd == SWAP_CMD:
            # A swap is a store immediately followed by a get of the same key.
            key = self._recv_key(rank, head_arg)
            self.handle_store(rank, key, *payload)
            self.handle_get(rank, key, False)
        elif cmd == JOIN_CMD:
            join_count += 1
            logger.info(f"ParameterServer join: join_count= {join_count}")
            if join_count != self.num_clients:
                # Still waiting for the remaining clients to join.
                continue
            ack = 0
            while ack < self.num_clients:
                # after sending the join cmd,
                # each client waits on this ack to know everyone is done
                # and it's safe to exit
                td.send(torch.zeros((1,)), dst=ack)
                ack += 1
            # For a join command, slot 1 carries the barrier flag; the flag
            # of the last join received decides whether to run the barrier.
            if head_arg:
                # NOTE(review): self.groups is None when self.group_idxs is
                # None, in which case this indexing would fail; presumably a
                # barrier is only requested when groups are set -- confirm.
                logger.info("ParameterServer barrier begin")
                td.barrier(self.groups[0])
                logger.info("ParameterServer barrier end")
            break
        else:
            raise RuntimeError(
                "Command is unknown value %d from rank %d." % (cmd, rank)
            )