import copy
import logging
import os
import sys
import time
import traceback
from bisect import bisect_left
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.multiprocessing as mp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from transformers.generation import GenerationConfig

import optimum.tpu.xla_logger as logger
from optimum.tpu.generation import TokenSelector
from optimum.tpu.modeling import AutoModelForCausalLM
from optimum.tpu.static_cache_xla import StaticCacheXla
from optimum.tpu.xla_mp_comm import AgentMailbox, RootMailbox

from .generator_base import Generator
from .pb.generate_pb2 import (
    Batch,
    CachedBatch,
    FinishReason,
    GeneratedText,
    Generation,
    InfoResponse,
    NextTokenChooserParameters,
    Request,
    StoppingCriteriaParameters,
    Tokens,
)


# Disable optimum-tpu warnings, as they seem to block the server after a while
optimum_logger = logging.getLogger("optimum.tpu")
optimum_logger.setLevel("CRITICAL")

# Prefill lengths are bucketed to these values to avoid compiling too many different input sizes
PREFILL_LENGTHS = list(range(6, 16)) + [
    16,
    32,
    64,
    128,
    256,
    512,
    1024,
    2048,
    4096,
    8192,
    16384,
    32768,
]

def take_nearest_length(length: int) -> int:
    """Return the nearest bucketed length >= `length`, or the largest bucket if `length` exceeds it."""
    pos = bisect_left(PREFILL_LENGTHS, length)
    if pos == len(PREFILL_LENGTHS):
        return PREFILL_LENGTHS[-1]
    return PREFILL_LENGTHS[pos]
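
# Illustrative mapping (derived from PREFILL_LENGTHS above):
#   take_nearest_length(5)     -> 6      (below the smallest bucket: round up to it)
#   take_nearest_length(64)    -> 64     (exact match)
#   take_nearest_length(100)   -> 128    (round up to the next bucket)
#   take_nearest_length(50000) -> 32768  (beyond the largest bucket: fall back to it)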

class Slot:
    """Represents a slot in a static batch"""

    class State(Enum):
        EMPTY = 0
        PAUSE = 1
        READY = 2

    def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase, device: Union[str, torch.device]):
        self._id = id
        self._tokenizer = tokenizer
        self.clear()
        self._device = device

    def clear(self):
        """Clear the slot and mark it as available."""
        self._state = Slot.State.EMPTY
        self._batch_id = None
        self._request_id = None
        self._inputs = ""
        self._generation_config = None
        self._tokens = []
        self._mask = None
        self._selector = None
        self._generated_tokens = 0
        self._next_text_token_start = 0
        self._next_text_token_end = 0
        self._generated_text = ""
        self._next_text = ""
        self._kv_cache = None
        self._truncate = 0
        self._position_id = 0

    @property
    def id(self) -> int:
        return self._id

    @property
    def state(self) -> "Slot.State":
        return self._state

    @property
    def batch_id(self) -> int:
        return self._batch_id

    @property
    def request_id(self) -> int:
        return self._request_id

    @property
    def cached_text(self) -> str:
        return self._inputs + self._generated_text

    @property
    def generation_config(self) -> GenerationConfig:
        return self._generation_config

    @property
    def generated_tokens(self) -> int:
        return self._generated_tokens

    @property
    def truncate(self) -> int:
        return self._truncate

    @property
    def position_id(self) -> int:
        return self._position_id

    @position_id.setter
    def position_id(self, cur_pos: int):
        self._position_id = cur_pos

    @property
    def cache_position(self) -> int:
        # This corresponds to the cache position for this slot
        return self._next_text_token_start


    def assign(self, batch_id: int, request: Request, generation_config: GenerationConfig):
        """Assign a request to a slot.

        Args:
            batch_id (`int`): The id of the batch containing the request.
            request (`Request`):
                The request to be assigned. Contains the inputs and tokens selection parameters.
            generation_config (`transformers.GenerationConfig`):
                The base generation config (might be modified by the request generation parameters).
        """
        self._state = Slot.State.READY
        self._batch_id = batch_id
        self._request_id = request.id
        self._inputs = request.inputs
        self._generation_config = copy.deepcopy(generation_config)
        # Update generation config with token chooser parameters
        self._generation_config.temperature = request.parameters.temperature
        self._generation_config.top_k = request.parameters.top_k
        self._generation_config.top_p = request.parameters.top_p
        self._generation_config.typical_p = request.parameters.typical_p
        self._generation_config.do_sample = request.parameters.do_sample
        self._generation_config.repetition_penalty = request.parameters.repetition_penalty
        self._truncate = request.truncate
        self.seed = request.parameters.seed
        # TODO: watermark
        self._generation_config.max_new_tokens = request.stopping_parameters.max_new_tokens
        self._max_new_tokens = self._generation_config.max_new_tokens
        # TODO: stop_sequences, ignore_eos_token

    def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor = None, selector: TokenSelector = None):
        """Reset the slot for the next generation.

        Args:
            input_ids (`torch.LongTensor`):
                The new input_ids to use to generate the next token.
            attention_mask (`torch.LongTensor`):
                The new attention_mask to use to generate the next token.
            selector (`TokenSelector`):
                An object implementing the updated token selection logic.
        """
        self._tokens = input_ids.cpu()
        self._next_text_token_start = 0
        self._next_text_token_end = torch.numel(self._tokens)
        self._next_text = ""
        if attention_mask is not None:
            self._mask = attention_mask.clone()
        else:
            self._mask = None
        self._selector = selector

    def pause(self):
        """Mark the current slot as paused for generation.

        Note that the KV cache for this slot will still be filled.
        """
        # Drop the last token as it will be added back when resuming the slot
        self._generated_tokens -= 1
        # Since generated tokens are now part of the prefill, we need to reevaluate
        # max_new_tokens for the next generation
        self._generation_config.max_new_tokens = self._max_new_tokens - self._generated_tokens
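        # e.g. if the request asked for 20 new tokens and 5 had been generated before the pause
        # (4 after dropping the last one above), up to 16 tokens can still be produced on resume.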
        self._state = Slot.State.PAUSE

    def resume(self):
        """Mark the slot as ready for generation."""
        self._state = Slot.State.READY

    def _decode_next_tokens(
        self,
    ) -> str:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
        # Copy the tokens to CPU to avoid recompilation on TPU. Post-processing is quite fast anyway.
        tokens = self._tokens.cpu()
        # We need to include the tokens that produced the last text to defeat cleanup algorithms in the decode
        # which decide to add a space or not depending on the surrounding ids.
        new_text = self._tokenizer.decode(tokens[self._next_text_token_start :], skip_special_tokens=False)
        if new_text.endswith("�"):
            # utf-8 char at the end means it's a potential unfinished byte sequence
            # from byte fallback tokenization.
            return ""

        # Compare the generated text with the one using only the tokens producing the last one
        last_text = self._tokenizer.decode(
            tokens[self._next_text_token_start : self._next_text_token_end],
            skip_special_tokens=False,
        )
        if len(new_text) == len(last_text):
            # Nothing new was actually generated
            return ""
        # Return the decoded text and store its token offsets
        self._next_text_token_start = self._next_text_token_end
        self._next_text_token_end = torch.numel(tokens)
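        # Illustration (tokenizer-dependent): if the previous window decoded to "Hello" and the
        # extended window decodes to "Hello world", only the new suffix " world" is returned.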
        return new_text[len(last_text) :]

    def append(self, next_token: int) -> str:
        """Append a new generated token to this slot

        The new token is added to the list of generated tokens, which directly
        impacts the generated_text and stopped properties.

        The new token is however not added immediately to the slot inputs: it will
        be added later on when it has effectively been used to produce the next token.

        Args:
            next_token (`int`):
                The newly generated token.

        Return:
            The corresponding decoded text (if any).
        """
        self._tokens = torch.cat([self._tokens, torch.tensor([next_token], dtype=self._tokens.dtype)])
        # Update mask only if it was set previously
        if self._mask is not None:
            self._mask = torch.cat([self._mask, torch.tensor([1], dtype=self._mask.dtype)])
        self._generated_tokens += 1
        next_text = self._decode_next_tokens()
        # Now that a new token has been generated, we can append the previous one to the generated text
        self._generated_text += self._next_text
        self._next_text = next_text
        return next_text

    def select(self, input_ids: torch.LongTensor, logits: torch.Tensor) -> torch.LongTensor:
        """Select the next token from the candidate logits.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation (not used in all generation modes).
            logits (`torch.Tensor` of shape `(batch_size, vocab_size)`):
                The next-token logits for each sequence.

        Return:
            `torch.LongTensor`: A scalar `torch.LongTensor` containing the selected token.
        """
        return self._selector.select(input_ids, logits)[0]

    @property
    def stopped(self) -> bool:
        # unsqueeze tokens to avoid problems with stopping criteria
        tokens = self._tokens.unsqueeze(0)
        return bool(torch.all(self._selector.stopping_criteria(tokens, None)))

    @property
    def generated_text(self) -> str:
        return self._generated_text + self._next_text

    @property
    def next_token(self) -> int:
        return None if len(self._tokens) == 0 else self._tokens[-1]

    @property
    def attention_mask(self) -> torch.LongTensor:
        return self._mask

    @property
    def max_token(self) -> int:
        return self._generation_config.max_length

    @property
    def max_new_tokens(self) -> int:
        # The current value of max_new_tokens: it may differ from the target max_new_tokens
        # if the slot has been paused and resumed.
        return self._generation_config.max_new_tokens
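
# Typical Slot lifecycle (an illustrative sketch of how the generator below drives a slot,
# not meant to be executed on its own):
#
#   slot = Slot(0, tokenizer, model.device)
#   slot.assign(batch_id, request, model.generation_config)   # state -> READY
#   slot.reset(input_ids, attention_mask, selector)           # after tokenization during prefill
#   text = slot.append(next_token)                            # once per generated token
#   if slot.stopped:
#       slot.clear()                                          # state -> EMPTY, slot can be reused
#
# When new requests arrive, prefill() pauses the active slots, re-prefills the whole static
# batch (including their cached text) and then resumes them, which is why pause()/resume() exist.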


class TpuGeneratorSingleThread(Generator):
    """A Generator for models running on TPU, single threaded."""

    def __init__(
        self,
        model,
        tokenizer: PreTrainedTokenizerBase,
    ):
        self.model = model
        # Specify padding options for decoder-only architecture
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.padding_side = "left"
        tokenizer.truncation_side = "left"
        self.tokenizer = tokenizer
        self.special_tokens = self.tokenizer.all_special_ids
        # The token selector will use the model's generation mixin internal variables to select the next token, and it
        # expects special tokens to be initialized in the model.
        model._prepare_special_tokens(generation_config=model.generation_config, device=model.device)
        # Slots are empty to begin with, they will be populated as new batches arrive
        self.slots = []
        self.batch_id = 0
        # Note: this index will _never_ be decremented, and that's fine.
        self.slot_index = 0
        self.past_key_values = None
        # _supports_static_cache is specific to some models (e.g.: Gemma and Llama).
        self._supports_static_cache = True
        if getattr(self.model, "_supports_static_cache", False) is False:
            logger.warning(
                f"Static cache not available for {self.model.__class__.__name__}. Performance will be affected"
            )
            self._supports_static_cache = False
        # compile model when possible to accelerate decoding
        if model.device.type == "xla" and ("DBG_COMPILE" in os.environ):
            self.model_one_token = torch.compile(model, backend="openxla")
            logger.debug("Model compiled for decoding")
        else:
            self.model_one_token = model

    @property
    def info(self) -> InfoResponse:
        """Returns the expected InfoResponse."""
        dtype = getattr(self.model.config, "torch_dtype", "float32")
        return InfoResponse(
            requires_padding=True,
            dtype=str(dtype),
            device_type="xla",
        )

    def _create_dummy_request(self, max_tokens: int) -> Batch:
        """Create a dummy request for warmup."""
        # Generate a random input with slightly more tokens than requested, because special tokens are going to be
        # skipped.
        MARGIN = 10
        input_tokens = torch.randint(self.model.config.vocab_size, (1, max_tokens + MARGIN), dtype=torch.int64)
        text = self.tokenizer.decode(input_tokens[0], skip_special_tokens=True)
        # These are just dummy params to allow Request creation
        parameters = NextTokenChooserParameters(
            temperature=1.0,
            top_k=None,
            top_p=None,
            do_sample=False,
            seed=None,
            repetition_penalty=1.0,
            typical_p=1.0,
        )
        stopping_parameters = StoppingCriteriaParameters(max_new_tokens=20, ignore_eos_token=True)
        dummy_request = Request(
            id=0,
            inputs=text,
            truncate=max_tokens,
            parameters=parameters,
            stopping_parameters=stopping_parameters,
        )
        return dummy_request


    def warmup(self, batch: Batch) -> int:
        """Verify if the hardware can support the target load.

        Args:
            batch (`Batch`):
                A batch corresponding to the maximum number of concurrent requests.

        Return:
            The maximum number of tokens the model supports.
        """
        logger.debug("Warming up the model")
        start = time.time()
        # Just check that the warmup request parameters match the model capacity
        # NOTE: later self.model.config.batch_size might become self.model.config.max_batch_size.
        if self.model.config.batch_size is not None:
            batch_size = self.model.config.batch_size
        else:
            # batch size is not set, just assume it's unlimited and accept all requests
            batch_size = len(batch.requests)
        if len(batch.requests) > batch_size:
            raise ValueError(
                f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length."
            )

        # Counter-intuitively, now we ignore the input batch. Instead, we create dummy batches to cover all possible
        # batch sizes and sequence lengths.
        seq_len = self.model.config.sequence_length
        if os.environ.get("SKIP_WARMUP", "0") == "1":
            logger.debug("Skipping warmup")
            return batch_size * seq_len
        bucket_seq_len = take_nearest_length(seq_len)
        requests = [self._create_dummy_request(seq_len) for _ in range(batch_size)]
        for _ in reversed(range(batch_size)):
            # Prefill with different truncate sizes to test all prefill lengths. The list is reversed so the
            # longest sequences are tested first and, if there is a memory failure, it appears sooner.
            for l in reversed(PREFILL_LENGTHS):
                # Skip all the unsupported lengths
                if l > bucket_seq_len:
                    continue
                # Set all truncate values for all requests
                for r in requests:
                    r.truncate = l
                    r.stopping_parameters.max_new_tokens = 10
                warmup_batch = Batch(id=0,
                                     requests=requests,
                                     size=len(requests),
                                     max_tokens=batch.max_tokens)
                logger.debug(f"Warmup for {len(requests)} requests, truncate value {l} seq_len {seq_len}")
                _generations, next_batch = self.prefill(warmup_batch)
                if next_batch is not None:
                    self.decode([next_batch])
                else:
                    logger.debug(f"No decode on warmup for {len(requests)}x{l}")
                self.clear()
            # remove the last request to decrease the batch size
            requests.pop()

        elapsed = time.time() - start
        logger.debug(f"Warmup done, took {elapsed:.2f}s")
        return batch_size * seq_len

    @torch.no_grad
    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
        """Prefill new requests.

        Args:
            batch (`Batch`):
                A batch containing the new requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
        """
        slots = {state: [] for state in Slot.State}
        for slot in self.slots:
            slots[slot.state].append(slot)
        active_slots = slots[Slot.State.READY]
        # Delete all empty slots, no need to have them anymore
        empty_slots = slots[Slot.State.EMPTY]
        model_batch_size = self.model.config.batch_size
        if model_batch_size is not None and model_batch_size < len(active_slots) + len(batch.requests):
            # Raising an error here would crash the server, so we only log it and
            # return an empty generation.
            error = ValueError(
                f"Cannot prefill {len(batch.requests)} new request(s)."
                f" Maximum batch size supported is: {model_batch_size}."
            )
            logger.error(error)
            return [], None
        for slot in empty_slots:
            self.slots.remove(slot)
        # Assign each request to an empty slot
        logger.debug(f"Prefilling {len(batch.requests)} new request(s) adding to {len(active_slots)} active slot(s)")
        for request in batch.requests:
            # Dynamically create a new slot for each request
            slot = Slot(self.slot_index, self.tokenizer, self.model.device)
            self.slot_index += 1
            slot.assign(self.batch_id, request, self.model.generation_config)
            self.slots.append(slot)
            logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
            logger.debug(
                f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
            )
        # Reconstruct the full inputs (without padding) as seen by the model.
        # This comprises:
        # - the inputs for new requests,
        # - the inputs and the generated text that has already been cached (i.e. excluding the last generated token)
        #   for unfinished requests.

        # Prepare inputs. They need to be tokenized and truncated afterwards.
        max_len = 0
        batch_inputs = []
        for slot in self.slots:
            batch_inputs.append(slot.cached_text)
            max_len = max(max_len, slot.truncate)
        if max_len == 0:
            max_len = self.model.config.sequence_length
        tokenized_inputs = self.tokenizer(batch_inputs,
                                          return_tensors="pt",
                                          padding=True,
                                          truncation=True,
                                          max_length=max_len)
        seq_length = tokenized_inputs.input_ids.size(-1)
        seq_length = min(seq_length, self.model.config.sequence_length)
        batch_size = len(self.slots)
        # Initialize input_ids and attention_mask with padding (to make them all the same size)
        input_ids = torch.full((batch_size, seq_length), self.tokenizer.pad_token_id, dtype=torch.int64)
        attention_mask = torch.full((batch_size, seq_length), 0, dtype=torch.int64)

        # Pause previously active slots during generation and store their last token.
        next_tokens = []
        for slot in active_slots:
            next_tokens.append(slot.next_token)
            slot.pause()
        # Each slot must be reset with the padded inputs and masks
        for i, slot in enumerate(self.slots):
            assert slot.state != Slot.State.EMPTY

            truncation = min(tokenized_inputs.input_ids.size(-1), input_ids.size(-1))
            if slot.truncate > 0:
                truncation = min(truncation, slot.truncate)
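            # Inputs are left-padded, so only the last `truncation` token ids (and, below, the
            # matching attention mask values) are copied into the right-aligned window of this slot.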
            input_ids[i, -truncation:] = tokenized_inputs.input_ids[i, -truncation:]
            slot_input_ids = input_ids[i : i + 1, :]
            # Padded input ids are also required to set logits processors and stopping criteria
            try:
                selector = TokenSelector.create(
                    slot_input_ids,
                    slot.generation_config,
                    self.model,
                    self.model.config.sequence_length,
                    seed=slot.seed,
                )
            except ValueError as e:
                # This is very unlikely, but it could happen if the router does not check the values beforehand.
                # In that case, we just skip the slot and mark it as empty, so it is not returned to the
                # router.
                logger.error(f"Invalid generation parameters for slot {slot.id}. Skipping it. Error: {e}")
                slot.clear()
                continue
            slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
            attention_mask[i, -truncation:] = tokenized_inputs.attention_mask[i, -truncation:]
            if self._supports_static_cache:
                # Attention mask does not need to be tracked when using static cache
                slot_attention_mask = None
            else:
                slot_attention_mask = attention_mask[i]
            slot.reset(slot_input_ids, slot_attention_mask, selector)
        # Clear KV cache
        self.past_key_values = None
        # Obtain position ids using attention mask.
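        # e.g. a left-padded mask [0, 0, 1, 1, 1] yields position ids [0, 0, 0, 1, 2].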
        position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
        # Save position id for every slot
        for slot, position_id in zip(self.slots, position_ids):
            slot.position_id = position_id.max().item() + 1

        extra_args = {}
        if self._supports_static_cache:
            self.past_key_values = StaticCacheXla(
                config=self.model.config,
                max_batch_size=len(self.slots),
                max_cache_len=self.model.config.sequence_length,
                device=self.model.device,
                dtype=self.model.dtype,
            )
            extra_args["cache_position"] = torch.arange(seq_length, device=self.model.device)
            extra_args["past_key_values"] = self.past_key_values
        else:
            # Reset/clear KV cache
            self.past_key_values = None
        generation, next_batch = self._generate_token(
            self.batch_id,
            input_ids.to(self.model.device),
            self.model,
            attention_mask=attention_mask.to(self.model.device),
            position_ids=position_ids.to(self.model.device),
            **extra_args,
        )
        self.batch_id += 1

        # Reactivate previously active slots for the next decode, and append
        # back their next token.
        for slot, next_token in zip(active_slots, next_tokens):
            slot.append(next_token)
            slot.resume()
        logger.debug("Model ready for decoding")
        if next_batch is not None:
            logger.debug(f"Next batch is {next_batch.id} with requests: {next_batch.request_ids}")
        return generation, next_batch

    @torch.no_grad
    def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
        """Decode the specified prefilled requests.

        Args:
            batches (`List[CachedBatch]`):
                A list of previous batches containing the prefilled requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
        """
        # batches contains a list composed of:
        # - the batch id returned by the last decode,
        # - the batch id(s) returned by the last prefill(s)
        # Batches are always concatenated during prefill, so we can
        # just carry on with decoding. We adopt the id of the first
        # batch in the list as our next batch id.
        next_batch_id = batches[0].id
        request_ids = []
        for batch in batches:
            request_ids += batch.request_ids
        cleared_request_ids = []
        for slot in self.slots:
            if slot.state == slot.State.READY and slot.request_id not in request_ids:
                cleared_request_ids.append(slot.request_id)
                slot.clear()
        if len(cleared_request_ids) > 0:
            logger.info(f"Clearing slot for requests {cleared_request_ids} as they are not requested.")
        active_slots = [slot for slot in self.slots if slot.state == slot.State.READY]
        if len(active_slots) < len(request_ids):
            logger.error("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
        # Reconstruct input_ids and attention_mask from slots
        input_ids = None
        attention_mask = None
        batch_size = len(self.slots)
        position_ids = torch.zeros(
            [batch_size, 1],
            dtype=torch.int64,
        )
        # init pad_token_id and input_ids
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            if isinstance(self.tokenizer.eos_token_id, list):
                pad_token_id = self.tokenizer.eos_token_id[0]
            else:
                pad_token_id = self.tokenizer.eos_token_id
        # Create blank inputs covering all slots (even empty ones)
        input_ids = torch.full(
            [batch_size, 1],
            fill_value=pad_token_id,
            dtype=torch.int64,
        )
        cache_position = torch.zeros([1], dtype=torch.int64)
        for i, slot in enumerate(self.slots):
            if slot.state != Slot.State.EMPTY:
                # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
                input_ids.index_put_([torch.tensor([i])], slot.next_token)
                if not self._supports_static_cache:
                    # When using dynamic cache, the whole attention mask needs to be passed over to the model at each iteration.
                    if attention_mask is None:
                        # Create default mask covering all slots (even empty ones)
                        attention_mask = torch.zeros(
                            [batch_size, slot.attention_mask.size(-1)],
                            dtype=torch.int64,
                        )
                    attention_mask.index_put_([torch.tensor([i])], slot.attention_mask)
                position_ids.index_put_([torch.tensor([i])], torch.tensor(slot.position_id))
                cache_position = torch.maximum(cache_position, torch.tensor([slot.cache_position]))
        if input_ids is None:
            raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
        extra_args = {}
        if self._supports_static_cache:
            extra_args["cache_position"] = position_ids.max().unsqueeze(0).to(self.model.device)
        else:
            extra_args["attention_mask"] = attention_mask.to(self.model.device)
        extra_args["past_key_values"] = self.past_key_values
        generations, next_batch = self._generate_token(
            next_batch_id,
            input_ids.to(self.model.device),
            self.model_one_token,
            position_ids=position_ids.to(self.model.device),
            **extra_args,
        )
        for slot, gen in zip(self.slots, generations):
            slot.position_id += len(gen.tokens.ids)

        return generations, next_batch

    def _generate_token(
        self, next_batch_id: int, input_ids: torch.LongTensor, model: torch.nn.Module, **forward_extra_params
    ) -> Tuple[List[Generation], CachedBatch]:
        # Add barrier to allow next graph step to always be the same
        xm.mark_step()
        # Forward
        outputs = model(
            input_ids,
            return_dict=True,
            use_cache=True,
            **forward_extra_params,
        )
        if not self._supports_static_cache:
            # Save KV cache
            self.past_key_values = outputs.past_key_values
        # Barrier for XLA model
        xm.mark_step()
        ret = self._post_generate(outputs, next_batch_id, input_ids)
        return ret

    def _post_generate(
        self, outputs: Dict, next_batch_id: int, input_ids: torch.LongTensor
    ) -> Tuple[List[Generation], CachedBatch]:
        generations = []
        active_slots = False
        for i, slot in enumerate(self.slots):
            if slot.state != Slot.State.READY:
                continue
            request_id = slot.request_id
            next_token_logits = outputs.logits[i : i + 1, -1, :]
            slot_input_ids = input_ids[i : i + 1, :]
            next_token = slot.select(slot_input_ids, next_token_logits)
            next_token_text = slot.append(next_token)
            generated_text = None
            finish_reason = None
            if next_token == self.tokenizer.eos_token_id:
                finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN
            elif slot.stopped:
                if slot.generated_tokens == slot.max_new_tokens:
                    finish_reason = FinishReason.FINISH_REASON_LENGTH
                else:
                    finish_reason = FinishReason.FINISH_REASON_STOP_SEQUENCE
            if finish_reason is not None:
                # We must include the generated text for each finished sequence in the response
                generated_text = GeneratedText(
                    text=slot.generated_text, generated_tokens=slot.generated_tokens, finish_reason=finish_reason
                )
                logger.debug(f"Decode complete for request {request_id} with {slot.generated_tokens} tokens")
                # This slot is now empty, it will be removed from the list of
                # active slots once a new prefill is requested
                slot.clear()
            else:
                active_slots = True
            generations.append(
                Generation(
                    request_id=request_id,
                    prefill_tokens=None,
                    tokens=Tokens(
                        ids=[next_token],
                        logprobs=[0],
                        texts=[next_token_text],
                        is_special=[next_token in self.special_tokens],
                    ),
                    generated_text=generated_text,
                )
            )
        batch = None
        if active_slots:
            # Whatever initial batch these requests came from, we always return all pending requests in a single batch
            request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY]
            batch = self._cached_batch(next_batch_id, request_ids)
        else:
            logger.debug("No more pending requests")
        return generations, batch

    def _cached_batch(self, batch_id: int, request_ids: List):
        size = len(request_ids)
        max_tokens = size * self.model.config.sequence_length
        return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens)

    def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch:
        """Remove requests that are not listed from the specified batch

        Args:
            batch_id (`int`):
                The id of a cached batch.
            keep_request_ids (`List[int]`):
                The list of requests that must be kept.

        Return:
            A `CachedBatch` containing the pending requests.
        """
        keep_slot_ids = [slot.id for slot in self.slots if slot.request_id in keep_request_ids]
        self._clear(keep_slot_ids)
        return self._cached_batch(batch_id, keep_request_ids)

    def clear(self, batch_id: Optional[int] = None):
        """Remove a subset or all requests from the generator"""
        keep_ids = []
        if batch_id is not None:
            keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id]
        return self._clear(keep_ids)

    def _clear(self, keep_slot_ids: List):
        for slot in self.slots:
            if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids:
                logger.debug(f"Removing slot {slot.id} with request {slot.request_id}")
                slot.clear()

    @classmethod
    def from_pretrained(cls, model_path: str, revision: str, max_batch_size: int, max_sequence_length: int):
        """Instantiate a TpuGenerator.

        Args:
            model_path (`str`):
                The path to a local model. This path must also contain a Tokenizer.
            revision (`str`):
                The revision of the model.
            max_batch_size (`int`):
                The maximum batch size.
            max_sequence_length (`int`):
                The maximum sequence length.

        Returns:
            A TpuGenerator.
        """
        logger.info("Loading model (this can take a few minutes).")
        start = time.time()
        model = AutoModelForCausalLM.from_pretrained(
            model_path, revision=revision, batch_size=max_batch_size, sequence_length=max_sequence_length
        )
        end = time.time()
        logger.info(f"Model successfully loaded in {end - start:.2f} s.")
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        return cls(model, tokenizer)


class GeneratorCommand(Enum):
    INFO = 0
    WARMUP = 1
    PREFILL = 2
    DECODE = 3
    FILTER = 4
    CLEAR = 5
    DELETE = -1


def _mp_fn(
    rank, model_path: str, revision: str, max_batch_size: int, max_sequence_length: int, root_mailbox: RootMailbox
):
    device = xm.xla_device()
    world_size = xm.xrt_world_size()
    # create agent mailbox out of root's one
    mailbox = AgentMailbox(root_mailbox)

    # re-init logger for each child process
    logger_level = os.environ.get("LOGGER_LEVEL_GENERATOR", "DEBUG")
    logger.logger.remove()
    logger.logger.add(
        sys.stdout,
        filter="text_generation_server",
        level=logger_level,
        backtrace=True,
        diagnose=False,
    )

    logger.debug(
        f"Rank {rank} on {device} real device {xm.xla_real_devices([device])} ordinal {xm.get_ordinal()} "
        + f"world size {world_size}"
    )

    generator = TpuGeneratorSingleThread.from_pretrained(model_path, revision, max_batch_size, max_sequence_length)
    # TODO: maybe model_config can be removed from mailbox

    def return_to_caller(*data):
        # consider adding a rendezvous here
        if rank == 0:
            xm.mark_step()
            mailbox.send(*data)

    while True:
        xm.rendezvous("start")
        if rank == 0:
            mailbox.agent_ready.set()
            mailbox.receive()
        # Wait for rank 0 to receive command
        xm.rendezvous("wait_command")
        command, data = mailbox.command_data
        logger.debug(f"Generator@{rank} {command.name}")
        try:
            if command == GeneratorCommand.INFO:
                info = generator.info
                return_to_caller(info.SerializeToString())
            if command == GeneratorCommand.WARMUP:
                batch = Batch.FromString(data[0])
                return_to_caller(generator.warmup(batch=batch))
            if command == GeneratorCommand.PREFILL:
                batch = Batch.FromString(data[0])
                generations, cached_batch = generator.prefill(batch=batch)
                s_cached_batch = cached_batch.SerializeToString() if cached_batch is not None else None
                return_to_caller([g.SerializeToString() for g in generations], s_cached_batch)
            if command == GeneratorCommand.DECODE:
                batches = [CachedBatch.FromString(b) for b in data[0]]
                generations, cached_batch = generator.decode(batches=batches)
                s_cached_batch = cached_batch.SerializeToString() if cached_batch is not None else None
                return_to_caller([g.SerializeToString() for g in generations], s_cached_batch)
            if command == GeneratorCommand.FILTER:
                batch_id, request_ids = data
                cached_batch = generator.filter(batch_id, request_ids)
                return_to_caller(cached_batch.SerializeToString())
            if command == GeneratorCommand.CLEAR:
                batch_id = data[0]
                generator.clear(batch_id)
            if command == GeneratorCommand.DELETE:
                if rank == 0:
                    # Set agent to ready
                    mailbox.agent_ready.set()
                break
        except Exception as e:
            logger.error(f"Error in command {command.name}")
            mailbox.agent_error.set()
            mailbox.agent_ready.set()
            exc_info = sys.exc_info()
            logger.error(''.join(traceback.format_exception(*exc_info)))
            raise e
        # If an error happened on any of the processes, all of them should exit
        if mailbox.agent_error.is_set():
            return


def model_loop_fn(*args):
    """Spawn processes in the TPUs forwarding arguments"""
    xmp.spawn(_mp_fn, args=args, join=True, daemon=False)


class TpuGenerator(Generator):
    """A Generator for models running on TPU.

    This generator actually spawns several processes to handle the requests in sharded models whenever possible.
    """

    def __init__(self, model_path: str, revision: str, max_batch_size: int, max_sequence_length: int):
        manager = mp.Manager()
        self.mailbox = RootMailbox(manager)

        # Disable parallelism on tokenizers to avoid deadlocks on TPU threads
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        self.model_loop = mp.Process(
            target=model_loop_fn, args=(model_path, revision, max_batch_size, max_sequence_length, self.mailbox)
        )
        self.model_loop.start()

    @property
    def info(self) -> InfoResponse:
        s_info = self.mailbox.send(GeneratorCommand.INFO, None)[0]
        return InfoResponse.FromString(s_info)

    def warmup(self, batch: Batch) -> int:
        return self.mailbox.send(GeneratorCommand.WARMUP, batch.SerializeToString())[0]

    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
        s_generations, s_cached_batch = self.mailbox.send(GeneratorCommand.PREFILL, batch.SerializeToString())
        generations = [Generation.FromString(g) for g in s_generations]
        cached_batch = CachedBatch.FromString(s_cached_batch) if s_cached_batch is not None else None
        return generations, cached_batch

    def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
        s_batches = [b.SerializeToString() for b in batches]
        s_generations, s_cached_batch = self.mailbox.send(GeneratorCommand.DECODE, s_batches)
        generations = [Generation.FromString(g) for g in s_generations]
        cached_batch = CachedBatch.FromString(s_cached_batch) if s_cached_batch is not None else None
        return generations, cached_batch

    def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
        s_cached_batch = self.mailbox.send(GeneratorCommand.FILTER, batch_id, request_ids)[0]
        return CachedBatch.FromString(s_cached_batch)

    def clear(self, batch_id: Optional[int] = None):
        self.mailbox.send(GeneratorCommand.CLEAR, batch_id)

    def leave(self):
        if self.mailbox is None:
            return
        self.mailbox.send(GeneratorCommand.DELETE)
        # Use Loguru's logger directly, to avoid errors while the TPU is shutting down
        logger.logger.debug("Joining...")
        self.model_loop.join()
        logger.logger.debug("Generator loop finished")
        self.mailbox = None

    @property
    def config(self):
        return self.mailbox.config

    def __del__(self):
        self.leave()

    @classmethod
    def from_pretrained(cls, model_path: str, revision: str, max_batch_size: int, max_sequence_length: int):
        """Instantiate a Generator distributed on as many cores as possible.

        Args:
            model_path (`str`):
                The path to a local model. This path must also contain a Tokenizer.
            revision (`str`):
                The revision of the model.
            max_batch_size (`int`):
                The maximum batch size.
            max_sequence_length (`int`):
                The maximum sequence length.

        Returns:
            A TpuGenerator.
        """
        return cls(model_path, revision, max_batch_size, max_sequence_length)
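
# Illustrative driver loop (a sketch of how a server would use TpuGenerator; the real TGI server
# goes through its gRPC handlers, and `warmup_batch`/`batch` below stand for protobuf Batch
# messages built by the router):
#
#   generator = TpuGenerator.from_pretrained(model_path, revision, max_batch_size, max_sequence_length)
#   max_supported_tokens = generator.warmup(warmup_batch)
#   generations, cached = generator.prefill(batch)
#   while cached is not None:
#       generations, cached = generator.decode([cached])
#   generator.leave()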
