# optimum/tpu/xla_mp_comm.py (39 lines of code) (raw):
from multiprocessing.managers import ListProxy
from typing import List
import torch.multiprocessing as mp
class RootMailbox:
    """Root-side endpoint of a small mailbox shared between the root process and the agents.

    Every synchronization object is created through the manager handed to the
    constructor, so the same objects can be handed over to an ``AgentMailbox``.
    """

    def __init__(self, manager: mp.Manager):
        # Raised by ``send`` once a new command has been written.
        self.root_bell = manager.Event()
        # Holds the pending command followed by its arguments.
        self.root_command = manager.list()
        # Waited on by ``send`` before posting a command and again afterwards.
        self.agent_ready = manager.Event()
        # Container handed back to the caller of ``send``.
        self.output_data = manager.list()
        # Checked by ``send`` once the command has been processed.
        self.agent_error = manager.Event()
        # Explicitly start from the "no error" state.
        self.agent_error.clear()

    def send(self, command: int, *args) -> ListProxy:
        """Post a command with its arguments and block until it has been handled.

        Args:
            command (int): Command identifier to publish.
            *args: Extra values stored right after the command.

        Returns:
            The shared output list (``self.output_data``) itself, not a copy.

        Raises:
            RuntimeError: If the error flag is set once the ready flag comes back up.
        """
        # Block until the ready flag is up, then lower it so that the second
        # wait below only returns after the flag has been raised again.
        self.agent_ready.wait()
        self.agent_ready.clear()

        # Overwrite the shared command list in place, then ring the bell.
        self.root_command[:] = [command] + list(args)
        self.root_bell.set()

        # The ready flag going up again means the command has been processed.
        self.agent_ready.wait()
        if not self.agent_error.is_set():
            return self.output_data
        raise RuntimeError("Error on one of threads, stopping.")
class AgentMailbox:
    """Agent-side endpoint of the mailbox, built on the objects owned by a root mailbox."""

    def __init__(self, root_mailbox: RootMailbox):
        # No new objects are created here: every attribute refers to the very
        # same object as the attribute of the same name on ``root_mailbox``.
        self.root_bell = root_mailbox.root_bell
        self.root_command = root_mailbox.root_command
        self.agent_ready = root_mailbox.agent_ready
        self.output_data = root_mailbox.output_data
        self.agent_error = root_mailbox.agent_error

    def receive(self) -> ListProxy:
        """Block until the root bell rings, reset it, and hand back the command list.

        Returns:
            The shared list holding the command followed by its arguments.
        """
        bell = self.root_bell
        bell.wait()
        # Lower the bell so the next call blocks until it is rung again.
        bell.clear()
        return self.root_command

    def send(self, *data):
        """Store the response for the root process.

        Args:
            *data: Values that replace the whole content of the shared output list.
        """
        # In-place slice assignment keeps the shared list object itself intact.
        self.output_data[:] = list(data)

    @property
    def command_data(self) -> tuple[int, List]:
        """Split the pending command list into its head and its tail.

        Returns:
            A ``(command, arguments)`` tuple: the first entry of the command
            list and a list with every remaining entry.
        """
        pending = self.root_command
        return pending[0], pending[1:]