def prefill()

in backends/neuron/server/text_generation_server/generator.py [0:0]


    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
        """Prefill new requests.

        Each request of the batch is assigned to an empty slot. The cached text
        of these slots is tokenized as a single padded batch, each slot is reset
        with its own row of the padded inputs, and the batch is submitted to the
        model. The model output is then handed to `_generate_token`.

        Slots that were in the READY state when the method was entered are
        paused around the model call and resumed before returning, including
        when an exception is raised in between.

        Args:
            batch (`Batch`):
                A batch containing the new requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
            The second element is whatever `_generate_token` returns and may be `None`.

        Raises:
            ValueError: if there are fewer empty slots than requests in the batch.
        """
        # Group the slots by state to find the ones that are generating and the free ones
        slots = {state: [] for state in Slot.State}
        for slot in self.slots:
            slots[slot.state].append(slot)
        active_slots = slots[Slot.State.READY]
        empty_slots = slots[Slot.State.EMPTY]
        if len(empty_slots) < len(batch.requests):
            raise ValueError(
                f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
                f" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}."
            )
        # Assign each request to an empty slot
        logger.debug(
            f"Prefilling {len(batch.requests)} new request(s) with {len(empty_slots)} empty slot(s)"
        )
        prefill_slots = []
        for request in batch.requests:
            slot = empty_slots.pop()
            slot.assign(self.batch_id, request, self.model.generation_config)
            prefill_slots.append(slot)
            logger.debug(
                f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
            )
        seq_ids = torch.tensor([slot.id for slot in prefill_slots])
        # Gather the (unpadded) text inputs of the newly assigned slots and
        # evaluate the length the whole batch must be truncated to.
        inputs = []
        max_length = 0
        prefill_limit = self.max_prefill_length()
        for slot in prefill_slots:
            inputs.append(slot.cached_text)
            # Apply truncation, making sure we fit into static dimensions.
            # A null truncation means "no per-request limit", and a truncation
            # at or above the prefill limit cannot be honored: both are clamped
            # to the prefill limit. Without the clamp, a batch containing only
            # such oversized requests would be tokenized with max_length=0.
            if slot.truncate == 0 or slot.truncate >= prefill_limit:
                max_length = prefill_limit
            elif slot.truncate > max_length:
                max_length = slot.truncate
        # Tokenize with padding and truncation
        padded_inputs = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        input_ids = padded_inputs.input_ids
        attention_mask = padded_inputs.attention_mask
        # One row of (top_k, top_p, temperature) per prefilled slot, only when sampling on device
        sampling_params = (
            torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None
        )
        # Pause previously active slots during generation
        for slot in active_slots:
            slot.pause()
        try:
            # Each slot must be reset with the padded inputs and masks
            for i, slot in enumerate(prefill_slots):
                if slot.state == Slot.State.EMPTY:
                    continue
                if slot.truncate > 0 and slot.truncate < input_ids.shape[-1]:
                    # Apply per-request truncation: only the last `truncate`
                    # positions are kept, the leading ones are turned into padding
                    input_ids[i, : -slot.truncate] = self.tokenizer.pad_token_id
                    attention_mask[i, : -slot.truncate] = 0
                slot_input_ids = input_ids[i : i + 1, :]
                # Padded input ids are also required to set logits processors and stopping criterias
                selector = TokenSelector.create(
                    slot_input_ids,
                    slot.generation_config,
                    self.model,
                    self.model.neuron_config.sequence_length,
                    tokenizer=self.tokenizer,
                    seed=slot.seed,
                )
                slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
                slot_attention_mask = attention_mask[i]
                slot.reset(slot_input_ids, slot_attention_mask, selector)
                if sampling_params is not None:
                    sampling_params[i, 0] = slot.generation_config.top_k
                    sampling_params[i, 1] = slot.generation_config.top_p
                    sampling_params[i, 2] = slot.generation_config.temperature
            # Only the newly assigned slots are submitted to the model: the
            # previously active slots stay paused for the duration of the call.
            model_inputs = self.model.prepare_inputs_for_prefill(
                input_ids,
                attention_mask=attention_mask,
                seq_ids=seq_ids,
                sampling_params=sampling_params,
            )
            tokens_or_logits = self.model(**model_inputs)[0]
            generation, next_batch = self._generate_token(
                prefill_slots, self.batch_id, tokens_or_logits, input_ids
            )
            self.batch_id += 1
        finally:
            # Reactivate previously active slots for the next decode, even if the
            # prefill failed, so that they are not left paused.
            for slot in active_slots:
                slot.resume()
        logger.debug("Model ready for decoding")
        if next_batch is not None:
            logger.debug(
                f"Next batch is {next_batch.id} with requests: {next_batch.request_ids}"
            )
        return generation, next_batch