optimum/tpu/distributed_model.py (79 lines of code) (raw):

# ruff: noqa: E402
import os
from enum import Enum

from loguru import logger


# NOTE: this is set before the torch_xla imports below (hence the E402
# suppression at the top of the file).
os.environ["PJRT_DEVICE"] = "TPU"

import torch.multiprocessing as mp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

from optimum.tpu.modeling import AutoModelForCausalLM

from .xla_mp_comm import AgentMailbox, RootMailbox


class ModelCommand(Enum):
    """Commands sent from the root process to the model worker processes."""

    LEAVE = 0
    PREFILL = 1
    DECODE = 2


def _mp_fn(rank, model_id, root_mailbox: RootMailbox, sample_fn: callable):
    """Worker entry point: load the model and serve commands until LEAVE.

    Args:
        rank: Index of this worker, as passed by ``xmp.spawn``.
        model_id: Identifier forwarded to ``AutoModelForCausalLM.from_pretrained``.
        root_mailbox: Mailbox created by the root process; an ``AgentMailbox``
            is built on top of it for this worker.
        sample_fn: Callable applied (on rank 0 only) to the first element of
            the model output to produce the next token tensor.
    """
    device = xm.xla_device()
    world_size = xm.xrt_world_size()
    # create agent mailbox out of root's one
    mailbox = AgentMailbox(root_mailbox)

    logger.debug(
        f"Rank {rank} on {device} real device {xm.xla_real_devices([device])} ordinal {xm.get_ordinal()} "
        + f"world size {world_size}"
    )

    # Model loading and sharding should happen here
    model = AutoModelForCausalLM.from_pretrained(model_id)
    model = model.eval()
    model.to(device)

    def get_next_token(inputs):
        """Run one forward pass; rank 0 samples and sends back the result."""
        # move inputs to device in a new dict to avoid conflicts
        model_inputs = {key: value.to(device) for key, value in inputs.items()}
        outputs = model(**model_inputs, return_dict=False)[0]
        xm.mark_step()
        # consider adding a rendezvous here
        if rank == 0:
            logger.debug(f"Rank {rank} getting tokens")
            next_token = sample_fn(outputs)
            xm.mark_step()
            logger.debug(f"Rank {rank} sending next_tokens {next_token.shape}")
            # Data needs to be moved to CPU before setting it
            mailbox.send(next_token.cpu())

    while True:
        if rank == 0:
            mailbox.agent_ready.set()
            logger.debug(f"Rank {rank} waiting for commands")
            mailbox.receive()
        # Wait for rank 0 to receive command
        xm.rendezvous("start")

        logger.debug(f"Rank {rank} waiting for command at rendezvous")
        command, data = mailbox.command_data
        inputs = data[0] if data else None
        if command in (ModelCommand.PREFILL, ModelCommand.DECODE):
            # Both commands currently run the same forward + sample step.
            logger.debug(f"Rank {rank} {command.name}")
            get_next_token(inputs)
        elif command == ModelCommand.LEAVE:
            logger.debug(f"Rank {rank} LEAVE")
            # Set model to ready
            mailbox.agent_ready.set()
            break


def model_loop_fn(*args):
    """Spawn processes in the TPUs forwarding arguments"""
    xmp.spawn(_mp_fn, args=args, join=True, daemon=False)


class DistributedModel:
    """Front-end that drives a model running in separately spawned processes.

    A background process is started at construction time; it spawns the
    workers running ``_mp_fn``. Commands are sent to them through a
    ``RootMailbox``.
    """

    def __init__(self, model_id: str, sample_fn: callable):
        """Create the mailbox and start the background model loop process.

        Args:
            model_id: Identifier forwarded to the workers for model loading.
            sample_fn: Callable forwarded to the workers to pick the next
                token from the model output.
        """
        manager = mp.Manager()
        self.mailbox = RootMailbox(manager)

        self.model_loop = mp.Process(target=model_loop_fn, args=(model_id, self.mailbox, sample_fn))
        self.model_loop.start()

    def prefill(self, **model_args):
        """Send a PREFILL command with ``model_args``; return the first reply item."""
        assert self.mailbox is not None, "DistributedModel is not initialized"
        return self.mailbox.send(ModelCommand.PREFILL, model_args)[0]

    def decode(self, **model_args):
        """Send a DECODE command with ``model_args``; return the first reply item."""
        assert self.mailbox is not None, "DistributedModel is not initialized"
        # Previously this sent ModelCommand.PREFILL (copy-paste slip).
        return self.mailbox.send(ModelCommand.DECODE, model_args)[0]

    def leave(self):
        """Ask the workers to exit and wait for the model loop process.

        Safe to call more than once: after the first successful call the
        mailbox is cleared and subsequent calls return immediately.
        """
        # getattr: __del__ may call this on an instance whose __init__ failed
        # before ``mailbox`` was assigned.
        if getattr(self, "mailbox", None) is None:
            return
        self.mailbox.send(ModelCommand.LEAVE)
        logger.debug("Joining...")
        self.model_loop.join()
        logger.debug("Model loop finished")
        self.mailbox = None

    def __del__(self):
        """Shut the workers down when the object is garbage-collected."""
        self.leave()