in text-generation-inference/server/text_generation_server/generator.py
def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
"""Prefill new requests.
Args:
batch (`Batch`):
A batch containing the new requests.
Return:
A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
"""
slots = {state: [] for state in Slot.State}
for slot in self.slots:
slots[slot.state].append(slot)
active_slots = slots[Slot.State.READY]
# Empty slots are no longer needed and are removed below
empty_slots = slots[Slot.State.EMPTY]
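# Reject the new requests if the resulting batch would exceed the maximum batch size supported by
# the model (only enforced when the model defines a static batch size).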
model_batch_size = self.model.config.batch_size
if model_batch_size is not None and model_batch_size < len(active_slots) + len(batch.requests):
# Raising an error here would crash the server, so we just log it and return an empty generation instead.
error = ValueError(
f"Cannot prefill {len(batch.requests)} new request(s)."
f" Maximum batch size supported is: {model_batch_size}."
)
logger.error(error)
return [], None
for slot in empty_slots:
self.slots.remove(slot)
# Assign each new request to a freshly created slot
logger.debug(f"Prefilling {len(batch.requests)} new request(s), adding to {len(active_slots)} active slot(s)")
for request in batch.requests:
# Dynamically create a new slot for each request
slot = Slot(self.slot_index, self.tokenizer, self.model.device)
self.slot_index += 1
slot.assign(self.batch_id, request, self.model.generation_config)
self.slots.append(slot)
logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
logger.debug(
f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
)
# Reconstruct the full inputs (without padding) as seen by the model.
# This comprises:
# - the inputs for new requests,
# - the inputs and the generated text that has already been cached (i.e. excluding the last generated token)
# for unfinished requests.
# Prepare the inputs: they are tokenized below and truncated to the largest per-slot truncation limit
# (or to the model sequence length when no limit is set).
max_len = 0
batch_inputs = []
for slot in self.slots:
batch_inputs.append(slot.cached_text)
max_len = max(max_len, slot.truncate)
if max_len == 0:
max_len = self.model.config.sequence_length
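# Tokenize all cached texts together, padding to the longest sequence and truncating to max_len.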
tokenized_inputs = self.tokenizer(batch_inputs,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_len)
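# Cap the sequence length to the maximum length supported by the model.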
seq_length = tokenized_inputs.input_ids.size(-1)
seq_length = min(seq_length, self.model.config.sequence_length)
batch_size = len(self.slots)
# Initialize input_ids and attention_mask with padding (to make them all the same size)
input_ids = torch.full((batch_size, seq_length), self.tokenizer.pad_token_id, dtype=torch.int64)
attention_mask = torch.full((batch_size, seq_length), 0, dtype=torch.int64)
# Pause previously active slots during generation and store their last token.
next_tokens = []
for slot in active_slots:
next_tokens.append(slot.next_token)
slot.pause()
# Each slot must be reset with the padded inputs and masks
for i, slot in enumerate(self.slots):
assert slot.state != Slot.State.EMPTY
truncation = min(tokenized_inputs.input_ids.size(-1), input_ids.size(-1))
if slot.truncate > 0:
truncation = min(truncation, slot.truncate)
input_ids[i, -truncation:] = tokenized_inputs.input_ids[i, -truncation:]
slot_input_ids = input_ids[i : i + 1, :]
# The padded input ids are also required to set up the logits processors and stopping criteria
try:
selector = TokenSelector.create(
slot_input_ids,
slot.generation_config,
self.model,
self.model.config.sequence_length,
seed=slot.seed,
)
except ValueError as e:
# This is very unlikely, but could happen if the router does not validate the values beforehand.
# In that case, we simply skip the slot and clear it (marking it as empty), so it is not returned
# to the router.
logger.error(f"Invalid generation parameters for slot {slot.id}. Skipping it. Error: {e}")
slot.clear()
continue
slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
attention_mask[i, -truncation:] = tokenized_inputs.attention_mask[i, -truncation:]
if self._supports_static_cache:
# Attention mask does not need to be tracked when using static cache
slot_attention_mask = None
else:
slot_attention_mask = attention_mask[i]
slot.reset(slot_input_ids, slot_attention_mask, selector)
# Clear KV cache
self.past_key_values = None
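# It is rebuilt below, either as a pre-allocated static cache or lazily during the model forward pass.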
# Obtain position ids using attention mask.
position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
# Save position id for every slot
for slot, position_id in zip(self.slots, position_ids):
slot.position_id = position_id.max().item() + 1
extra_args = {}
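# With static cache support, pre-allocate the cache for the full sequence length and pass explicit
# cache positions to the model.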
if self._supports_static_cache:
self.past_key_values = StaticCacheXla(
config=self.model.config,
max_batch_size=len(self.slots),
max_cache_len=self.model.config.sequence_length,
device=self.model.device,
dtype=self.model.dtype,
)
extra_args["cache_position"] = torch.arange(seq_length, device=self.model.device)
extra_args["past_key_values"] = self.past_key_values
else:
# Reset/clear KV cache
self.past_key_values = None
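# Run the prefill forward pass on the full padded batch: this fills the KV cache and produces a
# first token for the new requests.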
generation, next_batch = self._generate_token(
self.batch_id,
input_ids.to(self.model.device),
self.model,
attention_mask=attention_mask.to(self.model.device),
position_ids=position_ids.to(self.model.device),
**extra_args,
)
self.batch_id += 1
# Reactivate previously active slots for the next decode, and append
# back their next token.
for slot, next_token in zip(active_slots, next_tokens):
slot.append(next_token)
slot.resume()
logger.debug("Model ready for decoding")
if next_batch is not None:
logger.debug(f"Next batch is {next_batch.id} with requests: {next_batch.request_ids}")
return generation, next_batch