def decode()

in text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py [0:0]


    def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
        """Decode the specified prefilled requests.

        Generates one new token for every READY slot. Before generating, READY slots whose
        request id is not listed in any of the given batches are removed from `self.slots`.

        Args:
            batches (`List[CachedBatch]`):
                A list of previous batches containing the prefilled requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
            If `batches` is empty or contains an element that is not a `CachedBatch`, an error is logged
            and `([], None)` is returned instead.

        Raises:
            ValueError: if fewer slots are READY than there are requested ids (e.g. the requests were
                never prefilled), or if a selected slot is no longer READY after token generation.
        """

        # In python we should use type duck, but if elements passed on the list are not of the right type, this will
        # prevent raising an error and wasting time. Return an empty generation instead.
        if any(not isinstance(item, CachedBatch) for item in batches):
            logger.error("Unexpected type in decode, expected CachedBatch")
            return [], None
        # Nothing to decode; without this guard `batches[0]` below would raise an opaque IndexError.
        if not batches:
            logger.error("No batch provided to decode")
            return [], None

        # batches contains a list composed of ongoing requests:
        # - the batch id returned by the last decode,
        # - the batch id(s) returned by the last prefill(s)
        # Batches are always concatenated during prefill, so we can
        # just carry on with decoding. We adopt the id of the first
        # batch in the list as our next batch id.
        next_batch_id = batches[0].id
        if len(batches) > 1:
            # NOTE(review): despite this message, request ids from every batch are collected and decoded
            # below; only the returned batch id is taken from the first batch.
            logger.warning("Unexpected multiple batches received, only the first one will be processed.")
        # Keep the list (its length is compared against the READY slots below) and a set for O(1) lookups.
        request_ids = [request_id for batch in batches for request_id in batch.request_ids]
        requested = set(request_ids)

        # Iterate over a snapshot: removing from a list while iterating over it skips the element that
        # follows each removed one, so consecutive stale slots would otherwise survive.
        cleared_request_ids = []
        for slot in list(self.slots):
            if slot.state == slot.State.READY and slot.request_id not in requested:
                cleared_request_ids.append(slot.request_id)
                self.slots.remove(slot)
        if cleared_request_ids:
            logger.info(f"Clearing slot for requests {cleared_request_ids} as they are not requested.")

        active_slots = [slot for slot in self.slots if slot.state == slot.State.READY]
        if len(active_slots) < len(request_ids):
            raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")

        # Use a custom function to select the next token for each slot; only the updated decode state is used.
        self.decode_state, _ = self.engine.generate_impl(self.params, self.decode_state, self._select_from_slots)

        generations = []
        for slot in active_slots:
            # Slots were filtered on READY above; this guards against a state change during generation.
            if slot.state != slot.State.READY:
                logger.error(f"Unexpected Slot {slot.id} is not ready for decoding, skipping.")
                raise ValueError("Unexpected Slot is not ready for decoding")

            # Get the next token for this slot (decode_state.tokens is indexed by slot id).
            # Note that for now we ignore is_valid and length as we don't use them, we will re-parse these in post
            # generation.
            next_token = self.decode_state.tokens[slot.id].item()

            self._post_generate(slot, next_token, generations)

        cached_batch = self._cached_batch(next_batch_id, active_slots)
        return generations, cached_batch