def _mp_fn()

in optimum/tpu/distributed_model.py [0:0]


def _mp_fn(rank, model_id, root_mailbox: RootMailbox, sample_fn: callable):
    """Load the model on this process's XLA device and serve model commands.

    The function loads ``model_id`` with ``AutoModelForCausalLM``, puts it in
    eval mode, moves it to the XLA device and then loops forever:

    1. Rank 0 signals ``agent_ready`` and blocks in ``mailbox.receive()``.
    2. Every rank meets at the ``"start"`` rendezvous.
    3. Every rank reads ``mailbox.command_data`` and acts on the command:
       ``PREFILL`` and ``DECODE`` run one forward pass (rank 0 additionally
       samples and sends the result back), ``LEAVE`` signals ``agent_ready``
       and exits the loop. Any other command is ignored and the loop restarts.

    NOTE(review): presumably this is the per-process function handed to the
    XLA multiprocessing spawner -- confirm against the caller.

    Args:
        rank: Index of this process. Only rank 0 talks to the mailbox
            (receive/send) and runs ``sample_fn``.
        model_id: Identifier passed to ``AutoModelForCausalLM.from_pretrained``.
        root_mailbox: Root-side mailbox from which this process's
            ``AgentMailbox`` is built.
        sample_fn: Called on rank 0 with the first element of the model's
            output tuple. Its result must expose ``.shape`` and ``.cpu()``.
    """
    # Imported locally so this block does not rely on a module-level
    # ``import torch`` being present in the file.
    import torch

    device = xm.xla_device()
    world_size = xm.xrt_world_size()
    # create agent mailbox out of root's one
    mailbox = AgentMailbox(root_mailbox)

    logger.debug(
        f"Rank {rank} on {device} real device {xm.xla_real_devices([device])} ordinal {xm.get_ordinal()} "
        + f"world size {world_size}"
    )

    # Model loading and sharding should happen here
    model = AutoModelForCausalLM.from_pretrained(model_id)
    model = model.eval()
    model.to(device)

    def get_next_token(inputs):
        """Run one forward pass; on rank 0, sample and send the result."""
        # move inputs to device in a new dict to avoid conflicts
        model_inputs = {key: value.to(device) for key, value in inputs.items()}
        # eval() does not turn off autograd; this loop only does inference, so
        # skip graph recording instead of building one on every step.
        with torch.no_grad():
            outputs = model(**model_inputs, return_dict=False)[0]
        xm.mark_step()
        # consider adding a rendezvous here
        if rank == 0:
            logger.debug(f"Rank {rank} getting tokens")
            next_token = sample_fn(outputs)
            xm.mark_step()
            logger.debug(f"Rank {rank} sending next_tokens {next_token.shape}")
            # Data needs to be moved to CPU before setting it
            mailbox.send(next_token.cpu())

    while True:
        if rank == 0:
            mailbox.agent_ready.set()
            logger.debug(f"Rank {rank} waiting for commands")
            mailbox.receive()
        # Wait for rank 0 to receive command. The message is logged before the
        # rendezvous so that it describes the wait that is about to happen.
        logger.debug(f"Rank {rank} waiting for command at rendezvous")
        xm.rendezvous("start")

        command, data = mailbox.command_data
        # NOTE(review): with empty data `inputs` is None, and a PREFILL/DECODE
        # would then fail inside get_next_token on `inputs.items()` -- confirm
        # that senders always attach inputs to these commands.
        inputs = data[0] if data else None
        if command == ModelCommand.PREFILL:
            logger.debug(f"Rank {rank} PREFILL")
            get_next_token(inputs)
        elif command == ModelCommand.DECODE:
            logger.debug(f"Rank {rank} DECODE")
            get_next_token(inputs)
        elif command == ModelCommand.LEAVE:
            logger.debug(f"Rank {rank} LEAVE")
            # Set model to ready
            mailbox.agent_ready.set()
            break