def prefill()

in text-generation-inference/server/text_generation_server/generator.py


    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
        """Prefill new requests.

        Args:
            batch (`Batch`):
                A batch containing the new requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
        """
        slots = {state: [] for state in Slot.State}
        for slot in self.slots:
            slots[slot.state].append(slot)
        active_slots = slots[Slot.State.READY]
        # Collect the empty slots: they are removed below, as they are no longer needed
        empty_slots = slots[Slot.State.EMPTY]
        model_batch_size = self.model.config.batch_size
        if model_batch_size is not None and model_batch_size < len(active_slots) + len(batch.requests):
            # Raising an exception here would crash the server, so the error is only
            # built for logging purposes and an empty generation is returned instead.
            error = ValueError(
                f"Cannot prefill {len(batch.requests)} new request(s)."
                f" Maximum batch size supported is: {model_batch_size}."
            )
            logger.error(error)
            return [], None
        for slot in empty_slots:
            self.slots.remove(slot)
        # Create a new slot for each incoming request and assign the request to it
        logger.debug(f"Prefilling {len(batch.requests)} new request(s) adding to {len(active_slots)} active slot(s)")
        for request in batch.requests:
            # Dynamically create a new slot for each request
            slot = Slot(self.slot_index, self.tokenizer, self.model.device)
            self.slot_index += 1
            slot.assign(self.batch_id, request, self.model.generation_config)
            self.slots.append(slot)
            logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
            logger.debug(
                f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
            )
        # Reconstruct the full inputs (without padding) as seen by the model.
        # This comprises:
        # - the inputs for new requests,
        # - the inputs and the generated text that has already been cached (i.e. excluding the last generated token)
        #   for unfinished requests.

        # Prepare the inputs: they are tokenized and truncated below.
        max_len = 0
        batch_inputs = []
        for slot in self.slots:
            batch_inputs.append(slot.cached_text)
            max_len = max(max_len, slot.truncate)
        if max_len == 0:
            max_len = self.model.config.sequence_length
        tokenized_inputs = self.tokenizer(batch_inputs,
                                          return_tensors="pt",
                                          padding=True,
                                          truncation=True,
                                          max_length=max_len)
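        # With padding=True the tokenizer pads every prompt to the longest one in the
        # batch, and truncation=True caps the result at max_len tokens.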
        seq_length = tokenized_inputs.input_ids.size(-1)
        seq_length = min(seq_length, self.model.config.sequence_length)
        batch_size = len(self.slots)
        # Initialize input_ids and attention_mask with padding (to make them all the same size)
        input_ids = torch.full((batch_size, seq_length), self.tokenizer.pad_token_id, dtype=torch.int64)
        attention_mask = torch.full((batch_size, seq_length), 0, dtype=torch.int64)
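        # Both tensors start fully padded/masked; the real token ids and mask values are
        # copied in right-aligned in the loop below.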

        # Pause previously active slots during generation and store their last token.
        next_tokens = []
        for slot in active_slots:
            next_tokens.append(slot.next_token)
            slot.pause()
        # Each slot must be reset with the padded inputs and masks
        for i, slot in enumerate(self.slots):
            assert slot.state != Slot.State.EMPTY

            truncation = min(tokenized_inputs.input_ids.size(-1), input_ids.size(-1))
            if slot.truncate > 0:
                truncation = min(truncation, slot.truncate)
            input_ids[i, -truncation:] = tokenized_inputs.input_ids[i, -truncation:]
            slot_input_ids = input_ids[i : i + 1, :]
            # Padded input ids are also required to set logits processors and stopping criteria
            try:
                selector = TokenSelector.create(
                    slot_input_ids,
                    slot.generation_config,
                    self.model,
                    self.model.config.sequence_length,
                    seed=slot.seed,
                )
            except ValueError as e:
                # This is unlikely, but possible if the router does not validate the values beforehand.
                # In that case we skip the slot and mark it as empty, so that it is not returned to
                # the router.
                logger.error(f"Invalid generation parameters for slot {slot.id}. Skipping it. Error: {e}")
                slot.clear()
                continue
            slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
            attention_mask[i, -truncation:] = tokenized_inputs.attention_mask[i, -truncation:]
            if self._supports_static_cache:
                # Attention mask does not need to be tracked when using static cache
                slot_attention_mask = None
            else:
                slot_attention_mask = attention_mask[i]
            slot.reset(slot_input_ids, slot_attention_mask, selector)
        # Clear KV cache
        self.past_key_values = None
        # Obtain position ids using attention mask.
        position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
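        # e.g. a row [0, 0, 1, 1, 1] gives cumsum - 1 = [-1, -1, 0, 1, 2]; masking the padded
        # positions yields position ids [0, 0, 0, 1, 2].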
        # Save position id for every slot
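        # position_id.max() + 1 is the position at which each slot's next token will be written.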
        for slot, position_id in zip(self.slots, position_ids):
            slot.position_id = position_id.max().item() + 1

        extra_args = {}
        if self._supports_static_cache:
            self.past_key_values = StaticCacheXla(
                config=self.model.config,
                max_batch_size=len(self.slots),
                max_cache_len=self.model.config.sequence_length,
                device=self.model.device,
                dtype=self.model.dtype,
            )
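            # cache_position enumerates the cache entries written during this prefill pass,
            # i.e. positions 0..seq_length - 1.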
            extra_args["cache_position"] = torch.arange(seq_length, device=self.model.device)
            extra_args["past_key_values"] = self.past_key_values
        else:
            # Reset/clear KV cache
            self.past_key_values = None
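        # Run a single forward pass over the rebuilt batch; this produces the first generated
        # token for the new requests and the CachedBatch returned to the router.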
        generation, next_batch = self._generate_token(
            self.batch_id,
            input_ids.to(self.model.device),
            self.model,
            attention_mask=attention_mask.to(self.model.device),
            position_ids=position_ids.to(self.model.device),
            **extra_args,
        )
        self.batch_id += 1

        # Reactivate previously active slots for the next decode, and append
        # back their next token.
        for slot, next_token in zip(active_slots, next_tokens):
            slot.append(next_token)
            slot.resume()
        logger.debug("Model ready for decoding")
        if next_batch is not None:
            logger.debug(f"Next batch is {next_batch.id} with requests: {next_batch.request_ids}")
        return generation, next_batch
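
Below is a minimal, self-contained sketch (not part of the generator) that reproduces the padding and position-id logic above on toy data, so the tensor layout is easier to see; the token ids, pad id and sequence length are made up for illustration:

    import torch

    pad_token_id = 0
    seq_length = 6
    # Two toy prompts of different lengths, already "tokenized".
    prompts = [torch.tensor([11, 12, 13]), torch.tensor([21, 22, 23, 24, 25])]

    batch_size = len(prompts)
    input_ids = torch.full((batch_size, seq_length), pad_token_id, dtype=torch.int64)
    attention_mask = torch.zeros((batch_size, seq_length), dtype=torch.int64)
    for i, ids in enumerate(prompts):
        # Right-align the real tokens, mirroring the slice assignments in prefill().
        n = ids.size(0)
        input_ids[i, -n:] = ids
        attention_mask[i, -n:] = 1

    # Same position-id formula as in prefill().
    position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
    # input_ids    -> [[ 0,  0,  0, 11, 12, 13], [ 0, 21, 22, 23, 24, 25]]
    # position_ids -> [[ 0,  0,  0,  0,  1,  2], [ 0,  0,  1,  2,  3,  4]]
    # The next token for each row is written at position position_ids.max(-1).values + 1.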