def prefill()

in text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py


    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
        """Prefill new requests.

        Args:
            batch (`Batch`):
                A batch containing the new requests.

        Returns:
            A list of `Generation` objects, one per request, and a `CachedBatch` containing all pending requests.
        """

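        # Slots in the READY state are already assigned to requests and still decoding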
        active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY]
        len_active_slots = len(active_slots)

        len_requests = len(batch.requests)
        model_batch_size = self.model.config.batch_size
        if model_batch_size is not None and model_batch_size < len_active_slots + len_requests:
            # Raising here would crash the server, so we only log the error and
            # return an empty generation instead.
            error = ValueError(
                f"Cannot prefill {len_requests} new request(s)."
                f" Maximum batch size supported is: {model_batch_size}."
            )
            logger.error(error)
            return [], None
        # Assign each request to an empty slot
        logger.debug(f"Prefilling {len_requests} new request(s) adding to {len_active_slots} active slot(s)")
        generations = []
        prefilled_active_slots = []
        for request in batch.requests:
            # Dynamically create a new slot for each request
            slot = self._get_slot()
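            # Make this slot visible to the jit-compatible select callback passed as the sampler below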
            self.prefill_slot.set(slot)
            self.slot_index += 1
            slot.assign(self.batch_id, request, self.model.generation_config)
            logger.debug(
                f"Request {slot.request_id} assigned to slot {slot.id} with max_new_tokens {slot.max_new_tokens}"
            )

            # Tokenize the inputs
            input_ids, true_lengths = self._token_encode(request.inputs, slot.truncate)
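            # Keep only the real prompt tokens (the input is padded) and convert
            # the JAX array to a torch tensor for the selector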
            truncated_input_ids = torch_xla2.tensor.j2t(input_ids[:true_lengths]).to(torch.int64)
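            # The selector applies this request's generation config (logits
            # processing and sampling) when picking tokens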
            selector = TokenSelector.create(
                truncated_input_ids,
                slot.generation_config,
                self.model,
                self.model.config.sequence_length,
                seed=slot.seed,
            )
            slot.reset(truncated_input_ids, selector)
            # To allow jit'ing the select function, we need to wrap it in a partial
            slot_select = jax.tree_util.Partial(self.prefill_slot.select)
            # Ask for prefill and insert
            prefill_results, _result_tokens = self.engine.prefill_ex(
                params=self.params,
                padded_tokens=input_ids,
                true_length=true_lengths,
                sampler=slot_select,
            )
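            # prefill_ex already sampled the first token; insert its result into
            # the shared decode state at this slot's index so decoding can continue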
            next_token = prefill_results.token.item()
            self.decode_state = self.engine.insert(prefill_results, self.decode_state, slot.id)

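            # Record the generation for this request; a slot that already finished
            # is released (empty) and excluded from the next batch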
            self._post_generate(slot, next_token, generations)
            if not slot.empty:
                prefilled_active_slots.append(slot)

        cached_batch = self._cached_batch(self.batch_id, prefilled_active_slots)
        self.batch_id += 1
        logger.debug("Model ready for decoding")
        if cached_batch is not None:
            logger.debug(f"Next batch is {cached_batch.id} with requests: {cached_batch.request_ids}")
        return generations, cached_batch
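
For context, a minimal sketch of how a caller might drive this method, assuming a `generator` instance of this class, a `batch` built from incoming requests, and the companion `decode` method taking a list of cached batches (the names `generator`, `batch` and the loop itself are illustrative, not taken from this file):

    # Illustrative driver loop (hypothetical names, not part of this module).
    generations, cached_batch = generator.prefill(batch)
    # prefill emits the first token of every request; keep decoding until the
    # cached batch has no pending requests left.
    while cached_batch is not None:
        generations, cached_batch = generator.decode([cached_batch])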