in text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py [0:0]
def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
"""Decode the specified prefilled requests.
Args:
batches (`List[CachedBatch]`):
A list of previous batches containing the prefilled requests.
Return:
A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
"""
# In Python we would normally rely on duck typing, but checking the type here prevents raising an error
# later and wasting time if the elements passed in the list are not of the right type. Return an empty
# generation instead.
if any(not isinstance(item, CachedBatch) for item in batches):
logger.error("Unexpected type in decode, expected CachedBatch")
return [], None
# batches contains a list composed of ongoing requests:
# - the batch id returned by the last decode,
# - the batch id(s) returned by the last prefill(s)
# Batches are always concatenated during prefill, so we can
# just carry on with decoding. We adopt the id of the first
# batch in the list as our next batch id.
next_batch_id = batches[0].id
if len(batches) > 1:
logger.warning("Unexpected multiple batches received, only the first one will be processed.")
request_ids = []
for batch in batches:
request_ids += batch.request_ids
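# Free slots whose requests are no longer present in the incoming batches.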
cleared_request_ids = []
# Iterate over a copy of the slot list, since slots are removed from it inside the loop.
for slot in list(self.slots):
if slot.state == Slot.State.READY and slot.request_id not in request_ids:
cleared_request_ids.append(slot.request_id)
self.slots.remove(slot)
if len(cleared_request_ids) > 0:
logger.info(f"Clearing slot for requests {cleared_request_ids} as they are not requested.")
active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY]
if len(active_slots) < len(request_ids):
raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
# Use a custom function to select the next token for each slot
self.decode_state, result_tokens = self.engine.generate_impl(self.params, self.decode_state, self._select_from_slots)
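# The updated decode state now holds the selected token for every active slot; turn each one into a Generation.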
generations = []
for slot in active_slots:
# Get the next token.
# Note that for now we ignore is_valid and length, as we do not use them; they are re-parsed in
# post-generation.
next_token = self.decode_state.tokens[slot.id].item()
if slot.state != Slot.State.READY:
logger.error(f"Unexpected state: slot {slot.id} is not ready for decoding.")
raise ValueError("Unexpected Slot is not ready for decoding")
self._post_generate(slot, next_token, generations)
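# Build the CachedBatch describing the requests that are still pending, under the adopted batch id.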
cached_batch = self._cached_batch(next_batch_id, active_slots)
return generations, cached_batch
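# For reference, a minimal sketch of how decode() is driven, assuming a generator instance with a
# prefill() method returning (generations, cached_batch) as in the other TGI generators; the names and
# loop shape are illustrative, not the exact server logic:
#
#     generations, cached_batch = generator.prefill(batch)
#     while cached_batch is not None:
#         more_generations, cached_batch = generator.decode([cached_batch])
#         generations.extend(more_generations)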