in text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py [0:0]
def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
"""Prefill new requests.
Args:
batch (`Batch`):
A batch containing the new requests.
Return:
A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
"""
active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY]
len_active_slots = len(active_slots)
len_requests = len(batch.requests)
model_batch_size = self.model.config.batch_size
if model_batch_size is not None and model_batch_size < len_active_slots + len_requests:
# Raising an exception here would crash the server, so instead we log the error
# and return an empty generation to the caller
error = ValueError(
f"Cannot prefill {len_requests} new request(s)."
f" Maximum batch size supported is: {model_batch_size}."
)
logger.error(error)
return [], None
# Assign each request to an empty slot
logger.debug(f"Prefilling {len_requests} new request(s) adding to {len_active_slots} active slot(s)")
generations = []
prefilled_active_slots = []
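# For each request: pick a slot, tokenize the input, build a token selector,
# prefill the engine and insert the result into the shared decode state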
for request in batch.requests:
# Dynamically create a new slot for each request
slot = self._get_slot()
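# Keep track of the slot being prefilled so the select function wrapped below
# samples with this slot's generation parameters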
self.prefill_slot.set(slot)
self.slot_index += 1
slot.assign(self.batch_id, request, self.model.generation_config)
logger.debug(
f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
)
# Tokenize the inputs
input_ids, true_lengths = self._token_encode(request.inputs, slot.truncate)
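# Drop the padding and convert the tokens to a torch tensor for the selector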
truncated_input_ids = torch_xla2.tensor.j2t(input_ids[:true_lengths]).to(torch.int64)
selector = TokenSelector.create(
truncated_input_ids,
slot.generation_config,
self.model,
self.model.config.sequence_length,
seed=slot.seed,
)
slot.reset(truncated_input_ids, selector)
# To allow jit'ing the select function, we need to wrap it in a partial
slot_select = jax.tree_util.Partial(self.prefill_slot.select)
# Ask for prefill and insert
prefill_results, _result_tokens = self.engine.prefill_ex(
params=self.params,
padded_tokens=input_ids,
true_length=true_lengths,
sampler=slot_select,
)
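# The first token is sampled by the slot's selector during the prefill call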
next_token = prefill_results.token.item()
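# Merge the prefill results into the shared decode state at this slot's index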
self.decode_state = self.engine.insert(prefill_results, self.decode_state, slot.id)
self._post_generate(slot, next_token, generations)
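# The slot may already be empty if the request completed during prefill
# (e.g. it reached a stop condition on the first token)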
if not slot.empty:
prefilled_active_slots.append(slot)
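# The cached batch only references the slots that still have tokens to decode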
cached_batch = self._cached_batch(self.batch_id, prefilled_active_slots)
self.batch_id += 1
logger.debug("Model ready for decoding")
if cached_batch is not None:
logger.debug(f"Next batch is {cached_batch.id} with requests: {cached_batch.request_ids}")
return generations, cached_batch