in backends/neuron/server/text_generation_server/generator.py
def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
"""Prefill new requests.
Args:
batch (`Batch`):
A batch containing the new requests.
Returns:
A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
"""
slots = {state: [] for state in Slot.State}
for slot in self.slots:
slots[slot.state].append(slot)
active_slots = slots[Slot.State.READY]
empty_slots = slots[Slot.State.EMPTY]
if len(empty_slots) < len(batch.requests):
raise ValueError(
f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
f" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}."
)
# Assign each request to an empty slot
logger.debug(
f"Prefilling {len(batch.requests)} new request(s) with {len(empty_slots)} empty slot(s)"
)
new_slots = []
for request in batch.requests:
slot = empty_slots.pop()
slot.assign(self.batch_id, request, self.model.generation_config)
new_slots.append(slot)
logger.debug(
f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
)
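# The ids of the newly assigned slots double as the sequence ids passed to the model.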
prefill_slots = new_slots
seq_ids = torch.tensor([slot.id for slot in prefill_slots])
# Reconstruct the full inputs (without padding) as seen by the model.
# This comprises:
# - the inputs for new requests,
# - only when rebuilding the cache, the inputs and the generated text that has already
# been cached (i.e. excluding the last generated token) for unfinished requests.
inputs = []
max_length = 0
for slot in prefill_slots:
inputs.append(slot.cached_text)
# Apply truncation, making sure we fit into static dimensions
if slot.truncate == 0:
max_length = self.max_prefill_length()
elif (
slot.truncate > max_length and slot.truncate < self.max_prefill_length()
):
max_length = slot.truncate
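# At this point max_length is the static prefill length if any request is untruncated, otherwise the largest requested truncation below that limit.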
# Tokenize with padding and truncation
padded_inputs = self.tokenizer(
inputs,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_length,
)
input_ids = padded_inputs.input_ids
attention_mask = padded_inputs.attention_mask
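# When sampling happens on the device, the model takes a (batch_size, 3) tensor of per-request (top_k, top_p, temperature) values, filled in below.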
sampling_params = (
torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None
)
# Pause previously active slots during generation
for slot in active_slots:
slot.pause()
# Each slot must be reset with the padded inputs and masks
for i, slot in enumerate(prefill_slots):
if slot.state != Slot.State.EMPTY:
if slot.truncate > 0 and slot.truncate < input_ids.shape[-1]:
# Apply per-request truncation: keep only the last `slot.truncate` tokens and mask the rest as left padding
input_ids[i, : -slot.truncate] = self.tokenizer.pad_token_id
attention_mask[i, : -slot.truncate] = 0
slot_input_ids = input_ids[i : i + 1, :]
# Padded input ids are also required to set up logits processors and stopping criteria
selector = TokenSelector.create(
slot_input_ids,
slot.generation_config,
self.model,
self.model.neuron_config.sequence_length,
tokenizer=self.tokenizer,
seed=slot.seed,
)
slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
slot_attention_mask = attention_mask[i]
slot.reset(slot_input_ids, slot_attention_mask, selector)
if sampling_params is not None:
sampling_params[i, 0] = slot.generation_config.top_k
sampling_params[i, 1] = slot.generation_config.top_p
sampling_params[i, 2] = slot.generation_config.temperature
# Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
# as they have already been generated and sent back in the last decode.
model_inputs = self.model.prepare_inputs_for_prefill(
input_ids,
attention_mask=attention_mask,
seq_ids=seq_ids,
sampling_params=sampling_params,
)
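# The forward pass returns sampled token ids when on-device sampling is enabled, and raw logits otherwise.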
tokens_or_logits = self.model(**model_inputs)[0]
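# Turn the model output into one Generation per new request and a CachedBatch describing the pending requests.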
generation, next_batch = self._generate_token(
prefill_slots, self.batch_id, tokens_or_logits, input_ids
)
self.batch_id += 1
# Reactivate previously active slots for the next decode
for slot in active_slots:
slot.resume()
logger.debug("Model ready for decoding")
if next_batch is not None:
logger.debug(
f"Next batch is {next_batch.id} with requests: {next_batch.request_ids}"
)
return generation, next_batch