# text-generation-inference/server/text_generation_server/generator.py
import copy
import logging
import os
import sys
import time
import traceback
from bisect import bisect_left
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.multiprocessing as mp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from transformers.generation import GenerationConfig
import optimum.tpu.xla_logger as logger
from optimum.tpu.generation import TokenSelector
from optimum.tpu.modeling import AutoModelForCausalLM
from optimum.tpu.static_cache_xla import StaticCacheXla
from optimum.tpu.xla_mp_comm import AgentMailbox, RootMailbox
from .generator_base import Generator
from .pb.generate_pb2 import (
Batch,
CachedBatch,
FinishReason,
GeneratedText,
Generation,
InfoResponse,
NextTokenChooserParameters,
Request,
StoppingCriteriaParameters,
Tokens,
)
# Disable optimum-tpu warnings, as they seem to block the server after a while
optimum_logger = logging.getLogger("optimum.tpu")
optimum_logger.setLevel("CRITICAL")
# Bucket prefill lengths to avoid compiling for too many different input sizes
PREFILL_LENGTHS = list(range(6, 16)) + [
16,
32,
64,
128,
256,
512,
1024,
2048,
4096,
8192,
16384,
32768,
]
def take_nearest_length(length: int) -> int:
"""Gets the nearest length to the right in a set of lengths."""
pos = bisect_left(PREFILL_LENGTHS, length)
if pos == len(PREFILL_LENGTHS):
return PREFILL_LENGTHS[-1]
return PREFILL_LENGTHS[pos]
class Slot:
"""Represents a slot in a static batch"""
class State(Enum):
EMPTY = 0
PAUSE = 1
READY = 2
    def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase, device: Union[str, torch.device]):
self._id = id
self._tokenizer = tokenizer
self.clear()
self._device = device
def clear(self):
"""Clear the slot and mark it as available."""
self._state = Slot.State.EMPTY
self._batch_id = None
self._request_id = None
self._inputs = ""
self._generation_config = None
self._tokens = []
self._mask = None
self._selector = None
self._generated_tokens = 0
self._next_text_token_start = 0
self._next_text_token_end = 0
self._generated_text = ""
self._next_text = ""
self._kv_cache = None
self._truncate = 0
self._position_id = 0
@property
def id(self) -> int:
return self._id
@property
def state(self) -> "Slot.State":
return self._state
@property
def batch_id(self) -> int:
return self._batch_id
@property
def request_id(self) -> int:
return self._request_id
@property
def cached_text(self) -> str:
return self._inputs + self._generated_text
@property
def generation_config(self) -> GenerationConfig:
return self._generation_config
@property
def generated_tokens(self) -> int:
return self._generated_tokens
@property
def truncate(self) -> int:
return self._truncate
@property
def position_id(self) -> int:
return self._position_id
@position_id.setter
def position_id(self, cur_pos: int):
self._position_id = cur_pos
@property
def cache_position(self) -> int:
# This corresponds to the cache position for this slot
return self._next_text_token_start
def assign(self, batch_id: int, request: Request, generation_config: GenerationConfig):
"""Assign a request to a slot.
Args:
batch_id (`int`): The id of the batch containing the request.
request (`Request`):
The request to be assigned. Contains the inputs and tokens selection parameters.
generation_config (`transformers.GenerationConfig`):
The base generation config (might be modified by the request generation parameters).
"""
self._state = Slot.State.READY
self._batch_id = batch_id
self._request_id = request.id
self._inputs = request.inputs
self._generation_config = copy.deepcopy(generation_config)
# Update generation config with token chooser parameters
self._generation_config.temperature = request.parameters.temperature
self._generation_config.top_k = request.parameters.top_k
self._generation_config.top_p = request.parameters.top_p
self._generation_config.typical_p = request.parameters.typical_p
self._generation_config.do_sample = request.parameters.do_sample
self._generation_config.repetition_penalty = request.parameters.repetition_penalty
self._truncate = request.truncate
self.seed = request.parameters.seed
# TODO: watermark
self._generation_config.max_new_tokens = request.stopping_parameters.max_new_tokens
self._max_new_tokens = self._generation_config.max_new_tokens
# TODO: stop_sequences, ignore_eos_token
    def reset(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.LongTensor] = None, selector: Optional[TokenSelector] = None):
"""Reset the slot for the next generation.
Args:
            input_ids (`torch.LongTensor`):
                The new input_ids to use to generate the next token.
            attention_mask (`torch.LongTensor`, *optional*):
                The new attention_mask to use to generate the next token.
            selector (`TokenSelector`, *optional*):
                An object implementing the updated token selection logic.
"""
self._tokens = input_ids.cpu()
self._next_text_token_start = 0
self._next_text_token_end = torch.numel(self._tokens)
self._next_text = ""
if attention_mask is not None:
self._mask = attention_mask.clone()
else:
self._mask = None
self._selector = selector
def pause(self):
"""Mark the current slot as paused for generation.
Note that the KV cache for this slot will still be filled.
"""
# Drop the last token as it will be added back when resuming the slot
self._generated_tokens -= 1
# Since generated tokens are now part of the prefill, we need to reevaluate
# max_new_tokens for the next generation
self._generation_config.max_new_tokens = self._max_new_tokens - self._generated_tokens
self._state = Slot.State.PAUSE
def resume(self):
"""Mark the slot as ready for generation."""
self._state = Slot.State.READY
def _decode_next_tokens(
self,
) -> str:
"""Hack to hopefully support generate_stream for the maximum number of tokenizers"""
# Copy the tokens to CPU to avoid recompilation on TPU. Post-processing is quite fast anyway.
tokens = self._tokens.cpu()
# We need to include the tokens that produced the last text to defeat cleanup algorithms in the decode
# which decide to add a space or not depending on the surrounding ids.
new_text = self._tokenizer.decode(tokens[self._next_text_token_start :], skip_special_tokens=False)
if new_text.endswith("�"):
# utf-8 char at the end means it's a potential unfinished byte sequence
# from byte fallback tokenization.
return ""
# Compare the generated text with the one using only the tokens producing the last one
last_text = self._tokenizer.decode(
tokens[self._next_text_token_start : self._next_text_token_end],
skip_special_tokens=False,
)
if len(new_text) == len(last_text):
# Nothing new was actually generated
return ""
# Return the decoded text and store its token offsets
self._next_text_token_start = self._next_text_token_end
self._next_text_token_end = torch.numel(tokens)
return new_text[len(last_text) :]
def append(self, next_token: int) -> str:
"""Append a new generated token to this slot
        The new token is added to the list of generated tokens, which directly
        impacts the generated_text and stopped properties.
        However, the new token is not added to the slot inputs immediately: it is
        added later, once it has actually been used to produce the next token.
Args:
next_token (`int`):
The newly generated token.
Return:
The corresponding decoded text (if any).
"""
self._tokens = torch.cat([self._tokens, torch.tensor([next_token], dtype=self._tokens.dtype)])
# Update mask only if it was set previously
if self._mask is not None:
self._mask = torch.cat([self._mask, torch.tensor([1], dtype=self._mask.dtype)])
self._generated_tokens += 1
next_text = self._decode_next_tokens()
# Now that a new token has been generated, we can append the previous one to the generated text
self._generated_text += self._next_text
self._next_text = next_text
return next_text
def select(self, input_ids: torch.LongTensor, logits: torch.Tensor) -> torch.LongTensor:
"""Select the next token from the candidate logits.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation (not used in all generation modes).
            logits (`torch.Tensor` of shape `(batch_size, vocab_size)`):
                The logits from which the next token will be selected.
        Return:
            `torch.LongTensor`: A scalar `torch.LongTensor` containing the selected token.
"""
return self._selector.select(input_ids, logits)[0]
@property
def stopped(self) -> bool:
# unsqueeze tokens to avoid problems with stopping criteria
tokens = self._tokens.unsqueeze(0)
return bool(torch.all(self._selector.stopping_criteria(tokens, None)))
@property
def generated_text(self) -> str:
return self._generated_text + self._next_text
@property
def next_token(self) -> int:
return None if len(self._tokens) == 0 else self._tokens[-1]
@property
def attention_mask(self) -> torch.LongTensor:
return self._mask
@property
def max_token(self) -> int:
return self._generation_config.max_length
@property
def max_new_tokens(self) -> int:
        # The current value of max_new_tokens: it may differ from the target max_new_tokens
        # if the slot has been paused and resumed.
return self._generation_config.max_new_tokens
class TpuGeneratorSingleThread(Generator):
"""A Generator for models running on TPU, single threaded."""
def __init__(
self,
model,
tokenizer: PreTrainedTokenizerBase,
):
self.model = model
# Specify padding options for decoder-only architecture
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"
self.tokenizer = tokenizer
self.special_tokens = self.tokenizer.all_special_ids
# The token selector will use the model's generation mixin internal variables to select the next token, and it
# expects special tokens to be initialized in the model.
model._prepare_special_tokens(generation_config=model.generation_config, device=model.device)
# Slots are empty to begin with, they will be populated as new batches arrive
self.slots = []
self.batch_id = 0
# Note: this index will _never_ be decremented, and that's fine.
self.slot_index = 0
self.past_key_values = None
# _supports_static_cache is specific to some models (e.g.: Gemma and Llama).
self._supports_static_cache = True
if getattr(self.model, "_supports_static_cache", False) is False:
logger.warning(
f"Static cache not available for {self.model.__class__.__name__}. Performance will be affected"
)
self._supports_static_cache = False
# compile model when possible to accelerate decoding
if model.device.type == "xla" and ("DBG_COMPILE" in os.environ):
self.model_one_token = torch.compile(model, backend="openxla")
logger.debug("Model compiled for decoding")
else:
self.model_one_token = model
@property
def info(self) -> InfoResponse:
"""Returns the expected InfoResponse."""
dtype = getattr(self.model.config, "torch_dtype", "float32")
return InfoResponse(
requires_padding=True,
dtype=str(dtype),
device_type="xla",
)
def _create_dummy_request(self, max_tokens: int) -> Batch:
"""Create a dummy request for warmup."""
# Generate a random input with slightly more tokens than requested, because special tokens are going to be
# skipped.
MARGIN = 10
input_tokens = torch.randint(self.model.config.vocab_size, (1, max_tokens + MARGIN), dtype=torch.int64)
text = self.tokenizer.decode(input_tokens[0], skip_special_tokens=True)
# These are just dummy params to allow Request creation
parameters = NextTokenChooserParameters(
temperature=1.0,
top_k=None,
top_p=None,
do_sample=False,
seed=None,
repetition_penalty=1.0,
typical_p=1.0,
)
stopping_parameters = StoppingCriteriaParameters(max_new_tokens=20, ignore_eos_token=True)
dummy_request = Request(
id=0,
inputs=text,
truncate=max_tokens,
parameters=parameters,
stopping_parameters=stopping_parameters,
)
return dummy_request
def warmup(self, batch: Batch) -> int:
"""Verify if the hardware can support the target load.
Args:
batch (`Batch`):
A batch corresponding to the maximum number of concurrent requests.
Return:
The maximum number of tokens the model supports.
"""
logger.debug("Warming up the model")
start = time.time()
# Just check that the warmup request parameters match the model capacity
# NOTE: later self.model.config.batch_size might become self.model.config.max_batch_size.
if self.model.config.batch_size is not None:
batch_size = self.model.config.batch_size
else:
# batch size is not set, just assume it's unlimited and accept all requests
batch_size = len(batch.requests)
if len(batch.requests) > batch_size:
raise ValueError(
f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length."
)
# Counter-intuitively, now we ignore the input batch. Instead, we create dummy batches to cover all possible
# batch sizes and sequence lengths.
seq_len = self.model.config.sequence_length
if os.environ.get("SKIP_WARMUP", "0") == "1":
logger.debug("Skipping warmup")
return batch_size * seq_len
bucket_seq_len = take_nearest_length(seq_len)
requests = [self._create_dummy_request(seq_len) for _ in range(batch_size)]
for _ in reversed(range(batch_size)):
            # Prefill with different truncate sizes to test all prefill lengths. The list is reversed so the
            # longest sequences are tested first; that way, a memory failure will show up sooner.
for l in reversed(PREFILL_LENGTHS):
# Skip all the unsupported lengths
if l > bucket_seq_len:
continue
# Set all truncate values for all requests
for r in requests:
r.truncate = l
r.stopping_parameters.max_new_tokens = 10
warmup_batch = Batch(id=0,
requests=requests,
size=len(requests),
max_tokens=batch.max_tokens)
logger.debug(f"Warmup for {len(requests)} requests, truncate value {l} seq_len {seq_len}")
_generations, next_batch = self.prefill(warmup_batch)
if next_batch is not None:
self.decode([next_batch])
else:
logger.debug(f"No decode on warmup for {len(requests)}x{l}")
self.clear()
# remove the last requests to decrease the batch size
requests.pop()
elapsed = time.time() - start
logger.debug(f"Warmup done, took {elapsed:.2f}s")
return batch_size * seq_len
@torch.no_grad
def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
"""Prefill new requests.
Args:
batch (`Batch`):
A batch containing the new requests.
Return:
A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
"""
slots = {state: [] for state in Slot.State}
for slot in self.slots:
slots[slot.state].append(slot)
active_slots = slots[Slot.State.READY]
# Delete all empty slots, no need to have them anymore
empty_slots = slots[Slot.State.EMPTY]
model_batch_size = self.model.config.batch_size
if model_batch_size is not None and model_batch_size < len(active_slots) + len(batch.requests):
            # Raising an exception here would crash the server, so instead build the error,
error = ValueError(
f"Cannot prefill {len(batch.requests)} new request(s)."
f" Maximum batch size supported is: {model_batch_size}."
)
            # log it, and return an empty generation.
logger.error(error)
return [], None
for slot in empty_slots:
self.slots.remove(slot)
# Assign each request to an empty slot
logger.debug(f"Prefilling {len(batch.requests)} new request(s) adding to {len(active_slots)} active slot(s)")
for request in batch.requests:
# Dynamically create a new slot for each request
slot = Slot(self.slot_index, self.tokenizer, self.model.device)
self.slot_index += 1
slot.assign(self.batch_id, request, self.model.generation_config)
self.slots.append(slot)
logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
logger.debug(
f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
)
# Reconstruct the full inputs (without padding) as seen by the model.
# This comprises:
# - the inputs for new requests,
# - the inputs and the generated text that has already been cached (i.e. excluding the last generated token)
# for unfinished requests.
# Prepare inputs. They need to be tokenized and truncated afterwards.
max_len = 0
batch_inputs = []
for slot in self.slots:
batch_inputs.append(slot.cached_text)
max_len = max(max_len, slot.truncate)
if max_len == 0:
max_len = self.model.config.sequence_length
tokenized_inputs = self.tokenizer(batch_inputs,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_len)
seq_length = tokenized_inputs.input_ids.size(-1)
seq_length = min(seq_length, self.model.config.sequence_length)
batch_size = len(self.slots)
# Initialize input_ids and attention_mask with padding (to make them all the same size)
input_ids = torch.full((batch_size, seq_length), self.tokenizer.pad_token_id, dtype=torch.int64)
attention_mask = torch.full((batch_size, seq_length), 0, dtype=torch.int64)
# Pause previously active slots during generation and store their last token.
next_tokens = []
for slot in active_slots:
next_tokens.append(slot.next_token)
slot.pause()
# Each slot must be reset with the padded inputs and masks
for i, slot in enumerate(self.slots):
            assert slot.state != Slot.State.EMPTY
truncation = min(tokenized_inputs.input_ids.size(-1), input_ids.size(-1))
if slot.truncate > 0:
truncation = min(truncation, slot.truncate)
input_ids[i, -truncation:] = tokenized_inputs.input_ids[i, -truncation:]
slot_input_ids = input_ids[i : i + 1, :]
            # Padded input ids are also required to set up logits processors and stopping criteria
try:
selector = TokenSelector.create(
slot_input_ids,
slot.generation_config,
self.model,
self.model.config.sequence_length,
seed=slot.seed,
)
except ValueError as e:
                # This is very unlikely, but could happen if the router does not validate values beforehand.
                # In that case, just skip the slot and mark it as empty, so it is not returned to the router.
logger.error(f"Invalid generation parameters for slot {slot.id}. Skipping it. Error: {e}")
slot.clear()
continue
slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
attention_mask[i, -truncation:] = tokenized_inputs.attention_mask[i, -truncation:]
if self._supports_static_cache:
# Attention mask does not need to be tracked when using static cache
slot_attention_mask = None
else:
slot_attention_mask = attention_mask[i]
slot.reset(slot_input_ids, slot_attention_mask, selector)
# Clear KV cache
self.past_key_values = None
# Obtain position ids using attention mask.
position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
# Save position id for every slot
for slot, position_id in zip(self.slots, position_ids):
slot.position_id = position_id.max().item() + 1
extra_args = {}
if self._supports_static_cache:
self.past_key_values = StaticCacheXla(
config=self.model.config,
max_batch_size=len(self.slots),
max_cache_len=self.model.config.sequence_length,
device=self.model.device,
dtype=self.model.dtype,
)
extra_args["cache_position"] = torch.arange(seq_length, device=self.model.device)
extra_args["past_key_values"] = self.past_key_values
else:
# Reset/clear KV cache
self.past_key_values = None
generation, next_batch = self._generate_token(
self.batch_id,
input_ids.to(self.model.device),
self.model,
attention_mask=attention_mask.to(self.model.device),
position_ids=position_ids.to(self.model.device),
**extra_args,
)
self.batch_id += 1
# Reactivate previously active slots for the next decode, and append
# back their next token.
for slot, next_token in zip(active_slots, next_tokens):
slot.append(next_token)
slot.resume()
logger.debug("Model ready for decoding")
if next_batch is not None:
logger.debug(f"Next batch is {next_batch.id} with requests: {next_batch.request_ids}")
return generation, next_batch
@torch.no_grad
def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
"""Decode the specified prefilled requests.
Args:
batches (`List[CachedBatch]`):
A list of previous batches containing the prefilled requests.
Return:
A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
"""
# batches contains a list composed of:
# - the batch id returned by the last decode,
# - the batch id(s) returned by the last prefill(s)
# Batches are always concatenated during prefill, so we can
# just carry on with decoding. We adopt the id of the first
# batch in the list as our next batch id.
next_batch_id = batches[0].id
request_ids = []
for batch in batches:
request_ids += batch.request_ids
cleared_request_ids = []
        for slot in self.slots:
            if slot.state == Slot.State.READY and slot.request_id not in request_ids:
                cleared_request_ids.append(slot.request_id)
                slot.clear()
        if len(cleared_request_ids) > 0:
            logger.info(f"Clearing slots for requests {cleared_request_ids} as they are no longer requested.")
        active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY]
if len(active_slots) < len(request_ids):
logger.error("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
# Reconstruct input_ids and attention_mask from slots
input_ids = None
attention_mask = None
batch_size = len(self.slots)
position_ids = torch.zeros(
[batch_size, 1],
dtype=torch.int64,
)
# init pad_token_id and input_ids
pad_token_id = self.tokenizer.pad_token_id
if pad_token_id is None:
if isinstance(self.tokenizer.eos_token_id, list):
pad_token_id = self.tokenizer.eos_token_id[0]
else:
pad_token_id = self.tokenizer.eos_token_id
# Create blank inputs covering all slots (even empty ones)
input_ids = torch.full(
[batch_size, 1],
fill_value=pad_token_id,
dtype=torch.int64,
)
cache_position = torch.zeros([1], dtype=torch.int64)
for i, slot in enumerate(self.slots):
if slot.state != Slot.State.EMPTY:
# input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
input_ids.index_put_([torch.tensor([i])], slot.next_token)
if not self._supports_static_cache:
# When using dynamic cache, the whole attention mask needs to be passed over to the model at each iteration.
if attention_mask is None:
# Create default mask covering all slots (even empty ones)
attention_mask = torch.zeros(
[batch_size, slot.attention_mask.size(-1)],
dtype=torch.int64,
)
attention_mask.index_put_([torch.tensor([i])], slot.attention_mask)
position_ids.index_put_([torch.tensor([i])], torch.tensor(slot.position_id))
cache_position = torch.maximum(cache_position, torch.tensor([slot.cache_position]))
if input_ids is None:
raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
extra_args = {}
if self._supports_static_cache:
extra_args["cache_position"] = position_ids.max().unsqueeze(0).to(self.model.device)
else:
extra_args["attention_mask"] = attention_mask.to(self.model.device)
extra_args["past_key_values"] = self.past_key_values
generations, next_batch = self._generate_token(
next_batch_id,
input_ids.to(self.model.device),
self.model_one_token,
position_ids=position_ids.to(self.model.device),
**extra_args,
)
for slot, gen in zip(self.slots, generations):
slot.position_id += len(gen.tokens.ids)
return generations, next_batch
def _generate_token(
self, next_batch_id: int, input_ids: torch.LongTensor, model: torch.nn.Module, **forward_extra_params
) -> Tuple[List[Generation], CachedBatch]:
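        """Run a single forward pass on `model` and post-process its outputs.
        `model` is the plain model during prefill, or the (possibly compiled)
        single-token model during decode; extra keyword arguments are forwarded to it.
        """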
        # Add a barrier so that the next graph step is always the same
xm.mark_step()
# Forward
outputs = model(
input_ids,
return_dict=True,
use_cache=True,
**forward_extra_params,
)
if not self._supports_static_cache:
# Save KV cache
self.past_key_values = outputs.past_key_values
# Barrier for XLA model
xm.mark_step()
ret = self._post_generate(outputs, next_batch_id, input_ids)
return ret
def _post_generate(
self, outputs: Dict, next_batch_id: int, input_ids: torch.LongTensor
) -> Tuple[List[Generation], CachedBatch]:
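        """Select the next token for each active slot and build the `Generation` responses.
        Slots that reached a stop condition are cleared; the remaining pending requests
        are returned together as a single `CachedBatch`.
        """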
generations = []
active_slots = False
for i, slot in enumerate(self.slots):
if slot.state != Slot.State.READY:
continue
request_id = slot.request_id
next_token_logits = outputs.logits[i : i + 1, -1, :]
slot_input_ids = input_ids[i : i + 1, :]
next_token = slot.select(slot_input_ids, next_token_logits)
next_token_text = slot.append(next_token)
generated_text = None
finish_reason = None
if next_token == self.tokenizer.eos_token_id:
finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN
elif slot.stopped:
if slot.generated_tokens == slot.max_new_tokens:
finish_reason = FinishReason.FINISH_REASON_LENGTH
else:
finish_reason = FinishReason.FINISH_REASON_STOP_SEQUENCE
if finish_reason is not None:
# We must include the generated text for each finished sequence in the response
generated_text = GeneratedText(
text=slot.generated_text, generated_tokens=slot.generated_tokens, finish_reason=finish_reason
)
logger.debug(f"Decode complete for request {request_id} with {slot.generated_tokens} tokens")
# This slot is now empty, it will be removed from the list of
# active slots once a new prefill is requested
slot.clear()
else:
active_slots = True
generations.append(
Generation(
request_id=request_id,
prefill_tokens=None,
tokens=Tokens(
ids=[next_token],
logprobs=[0],
texts=[next_token_text],
is_special=[next_token in self.special_tokens],
),
generated_text=generated_text,
)
)
batch = None
if active_slots:
# Whatever initial batch these requests came from, we always return all pending requests in a single batch
request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY]
batch = self._cached_batch(next_batch_id, request_ids)
else:
logger.debug("No more pending requests")
return generations, batch
def _cached_batch(self, batch_id: int, request_ids: List):
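        """Build a `CachedBatch` for the given requests, with max_tokens sized for the full sequence length."""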
size = len(request_ids)
max_tokens = size * self.model.config.sequence_length
return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens)
def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch:
"""Remove requests that are not listed from the specified batch
Args:
batch_id (`int`):
The id of a cached batch.
            keep_request_ids (`List[int]`):
The list of requests that must be kept.
Return:
A `CachedBatch` containing the pending requests.
"""
keep_slot_ids = [slot.id for slot in self.slots if slot.request_id in keep_request_ids]
self._clear(keep_slot_ids)
return self._cached_batch(batch_id, keep_request_ids)
def clear(self, batch_id: Optional[int] = None):
"""Remove a subset or all requests from the generator"""
keep_ids = []
if batch_id is not None:
keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id]
return self._clear(keep_ids)
def _clear(self, keep_slot_ids: List):
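        """Clear every non-empty slot whose id is not in `keep_slot_ids`."""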
for slot in self.slots:
if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids:
logger.debug(f"Removing slot {slot.id} with request {slot.request_id}")
slot.clear()
@classmethod
def from_pretrained(cls, model_path: str, revision: str, max_batch_size: int, max_sequence_length: int):
"""Instantiate a TpuGenerator.
Args:
model_path (`str`):
The path to a local model. This path must also contain a Tokenizer.
revision (`str`):
The revision of the model.
max_batch_size (`int`):
The maximum batch size.
max_sequence_length (`int`):
The maximum sequence length.
Returns:
A TpuGenerator.
"""
logger.info("Loading model (this can take a few minutes).")
start = time.time()
model = AutoModelForCausalLM.from_pretrained(
model_path, revision=revision, batch_size=max_batch_size, sequence_length=max_sequence_length
)
end = time.time()
logger.info(f"Model successfully loaded in {end - start:.2f} s.")
tokenizer = AutoTokenizer.from_pretrained(model_path)
return cls(model, tokenizer)
class GeneratorCommand(Enum):
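    """Commands sent by the root process to the per-device model processes."""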
INFO = 0
WARMUP = 1
PREFILL = 2
DECODE = 3
FILTER = 4
CLEAR = 5
DELETE = -1
def _mp_fn(
rank, model_path: str, revision: str, max_batch_size: int, max_sequence_length: int, root_mailbox: RootMailbox
):
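    """Per-device entry point spawned by `xmp.spawn`: loads the model, then serves mailbox commands in a loop."""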
device = xm.xla_device()
world_size = xm.xrt_world_size()
# create agent mailbox out of root's one
mailbox = AgentMailbox(root_mailbox)
# re-init logger for each child process
logger_level = os.environ.get("LOGGER_LEVEL_GENERATOR", "DEBUG")
logger.logger.remove()
logger.logger.add(
sys.stdout,
filter="text_generation_server",
level=logger_level,
backtrace=True,
diagnose=False,
)
logger.debug(
f"Rank {rank} on {device} real device {xm.xla_real_devices([device])} ordinal {xm.get_ordinal()} "
+ f"world size {world_size}"
)
generator = TpuGeneratorSingleThread.from_pretrained(model_path, revision, max_batch_size, max_sequence_length)
# TODO: maybe model_config can be removed from mailbox
def return_to_caller(*data):
# consider adding a rendezvous here
if rank == 0:
xm.mark_step()
mailbox.send(*data)
while True:
xm.rendezvous("start")
if rank == 0:
mailbox.agent_ready.set()
mailbox.receive()
# Wait for rank 0 to receive command
xm.rendezvous("wait_command")
command, data = mailbox.command_data
logger.debug(f"Generator@{rank} {command.name}")
try:
if command == GeneratorCommand.INFO:
info = generator.info
return_to_caller(info.SerializeToString())
if command == GeneratorCommand.WARMUP:
batch = Batch.FromString(data[0])
return_to_caller(generator.warmup(batch=batch))
if command == GeneratorCommand.PREFILL:
batch = Batch.FromString(data[0])
generations, cached_batch = generator.prefill(batch=batch)
s_cached_batch = cached_batch.SerializeToString() if cached_batch is not None else None
return_to_caller([g.SerializeToString() for g in generations], s_cached_batch)
if command == GeneratorCommand.DECODE:
batches = [CachedBatch.FromString(b) for b in data[0]]
generations, cached_batch = generator.decode(batches=batches)
s_cached_batch = cached_batch.SerializeToString() if cached_batch is not None else None
return_to_caller([g.SerializeToString() for g in generations], s_cached_batch)
if command == GeneratorCommand.FILTER:
batch_id, request_ids = data
cached_batch = generator.filter(batch_id, request_ids)
return_to_caller(cached_batch.SerializeToString())
if command == GeneratorCommand.CLEAR:
batch_id = data[0]
generator.clear(batch_id)
if command == GeneratorCommand.DELETE:
if rank == 0:
# Set agent to ready
mailbox.agent_ready.set()
break
except Exception as e:
logger.error(f"Error in command {command.name}")
mailbox.agent_error.set()
mailbox.agent_ready.set()
exc_info = sys.exc_info()
logger.error(''.join(traceback.format_exception(*exc_info)))
raise e
        # If the error happened on only one of the processes, all of them should exit
if mailbox.agent_error.is_set():
return
def model_loop_fn(*args):
"""Spawn processes in the TPUs forwarding arguments"""
    xmp.spawn(_mp_fn, args=args, join=True, daemon=False)
class TpuGenerator(Generator):
"""A Generator for models running on TPU.
    This generator spawns several processes so that requests can be handled by sharded models whenever possible.
"""
def __init__(self, model_path: str, revision: str, max_batch_size: int, max_sequence_length: int):
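        """Start the model loop in a separate process; the model itself is loaded in the child processes."""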
manager = mp.Manager()
self.mailbox = RootMailbox(manager)
# Disable parallelism on tokenizers to avoid deadlocks on TPU threads
os.environ["TOKENIZERS_PARALLELISM"] = "false"
self.model_loop = mp.Process(
target=model_loop_fn, args=(model_path, revision, max_batch_size, max_sequence_length, self.mailbox)
)
self.model_loop.start()
@property
def info(self) -> InfoResponse:
s_info = self.mailbox.send(GeneratorCommand.INFO, None)[0]
return InfoResponse.FromString(s_info)
def warmup(self, batch: Batch) -> int:
return self.mailbox.send(GeneratorCommand.WARMUP, batch.SerializeToString())[0]
def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
s_generations, s_cached_batch = self.mailbox.send(GeneratorCommand.PREFILL, batch.SerializeToString())
generations = [Generation.FromString(g) for g in s_generations]
cached_batch = CachedBatch.FromString(s_cached_batch) if s_cached_batch is not None else None
return generations, cached_batch
def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
s_batches = [b.SerializeToString() for b in batches]
s_generations, s_cached_batch = self.mailbox.send(GeneratorCommand.DECODE, s_batches)
generations = [Generation.FromString(g) for g in s_generations]
cached_batch = CachedBatch.FromString(s_cached_batch) if s_cached_batch is not None else None
return generations, cached_batch
def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
s_cached_batch = self.mailbox.send(GeneratorCommand.FILTER, batch_id, request_ids)[0]
return CachedBatch.FromString(s_cached_batch)
def clear(self, batch_id: Optional[int] = None):
self.mailbox.send(GeneratorCommand.CLEAR, batch_id)
def leave(self):
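        """Ask the model processes to exit and wait for the model loop to join."""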
if self.mailbox is None:
return
self.mailbox.send(GeneratorCommand.DELETE)
        # Use Loguru's logger directly, to avoid errors while the TPU is shutting down
logger.logger.debug("Joining...")
self.model_loop.join()
logger.logger.debug("Generator loop finished")
self.mailbox = None
@property
def config(self):
return self.mailbox.config
def __del__(self):
self.leave()
@classmethod
def from_pretrained(cls, model_path: str, revision: str, max_batch_size: int, max_sequence_length: int):
"""Instantiate a Generator distributed on as many cores as possible.
Args:
model_path (`str`):
The path to a local model. This path must also contain a Tokenizer.
revision (`str`):
The revision of the model.
max_batch_size (`int`):
The maximum batch size.
max_sequence_length (`int`):
The maximum sequence length.
Returns:
A TpuGenerator.
"""
return cls(model_path, revision, max_batch_size, max_sequence_length)