def start()

in torchbiggraph/parameter_sharing.py [0:0]


    def start(self, groups: List["td.ProcessGroup"]) -> None:
        """Run the server command loop until ``self.num_clients`` joins arrive.

        Each iteration blocks in ``td.recv`` on a 6-element ``torch.long``
        command header sent by any rank, then dispatches on its first element:

        * ``STORE_CMD``: receive a key via ``self._recv_key`` (using header
          slot 1), then pass the key and header slots 2-5 to
          ``self.handle_store``.
        * ``GET_CMD``: receive a key, then call ``self.handle_get`` with header
          slot 2 as its third argument.
        * ``SWAP_CMD``: the same as ``STORE_CMD``, followed by
          ``self.handle_get(rank, key, False)`` for the same rank and key.
        * ``JOIN_CMD``: count the join. When the count reaches
          ``self.num_clients``, send a one-element ack tensor to ranks
          ``0 .. num_clients - 1``. If header slot 1 of that final JOIN is
          non-zero, run ``td.barrier`` on the first selected group. Then return.

        Args:
            groups: Process groups. The entries at ``self.group_idxs`` are kept
                in ``self.groups``; ``self.groups`` is None when
                ``self.group_idxs`` is None.

        Raises:
            RuntimeError: If an unknown command is received, or if a barrier is
                requested but no process group was selected.
        """
        self.groups = (
            [groups[idx] for idx in self.group_idxs]
            if self.group_idxs is not None
            else None
        )
        join_count = 0
        while True:
            # 1. Receive a command header from any rank.
            # Slots the sender does not fill keep the -1 fill value.
            cmd_buffer = torch.full((6,), -1, dtype=torch.long)
            rank = td.recv(cmd_buffer)
            cmd = cmd_buffer[0].item()

            if cmd in (STORE_CMD, SWAP_CMD):
                # STORE and SWAP share the same header layout. Slot 1 goes to
                # _recv_key (presumably the key length; verify) and slots 2-5
                # go to handle_store.
                key = self._recv_key(rank, cmd_buffer[1].item())
                self.handle_store(rank, key, *cmd_buffer[2:6].tolist())
                if cmd == SWAP_CMD:
                    # A swap additionally serves the same key back to `rank`
                    # after storing.
                    self.handle_get(rank, key, False)
            elif cmd == GET_CMD:
                key = self._recv_key(rank, cmd_buffer[1].item())
                self.handle_get(rank, key, cmd_buffer[2].item())
            elif cmd == JOIN_CMD:
                join_count += 1
                logger.info("ParameterServer join: join_count= %d", join_count)
                if join_count == self.num_clients:
                    for r in range(self.num_clients):
                        # After sending the join cmd, each client waits on this
                        # ack to know everyone is done and it's safe to exit.
                        td.send(torch.zeros((1,)), dst=r)
                    # NOTE(review): only the header of the *last* JOIN decides
                    # whether to barrier. Presumably every client sends the
                    # same flag; confirm against the client side.
                    do_barrier = cmd_buffer[1].item()
                    if do_barrier:
                        # Validated only after every client has been acked, so
                        # a misconfiguration here does not leave clients
                        # blocked on their ack.
                        if not self.groups:
                            raise RuntimeError(
                                "ParameterServer barrier requested but no process "
                                "group was selected (group_idxs=%r)."
                                % (self.group_idxs,)
                            )
                        logger.info("ParameterServer barrier begin")
                        td.barrier(self.groups[0])
                        logger.info("ParameterServer barrier end")
                    break
            else:
                raise RuntimeError(
                    "Command is unknown value %d from rank %d." % (cmd, rank)
                )