def decode()

in text-generation-inference/server/text_generation_server/generator.py [0:0]


    def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
        """Decode the specified prefilled requests.

        Runs one decoding step over all slots. Each non-empty slot contributes the
        last token it produced (the other tokens are cached and passed to the model
        through ``past_key_values``). Empty slots are filled with a pad token so the
        model input always covers every slot.

        Args:
            batches (`List[CachedBatch]`):
                A list of previous batches containing the prefilled requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.

        Raises:
            ValueError: if no slot holds a request that can be decoded.
        """
        # batches contains a list composed of:
        # - the batch id returned by the last decode,
        # - the batch id(s) returned by the last prefill(s)
        # Batches are always concatenated during prefill, so we can
        # just carry on with decoding. We adopt the id of the first
        # batch in the list as our next batch id.
        next_batch_id = batches[0].id
        request_ids = [request_id for batch in batches for request_id in batch.request_ids]
        # Set for O(1) membership tests; the list is kept for the length check below.
        requested_ids = set(request_ids)
        cleared_request_ids = []
        for slot in self.slots:
            if slot.state == slot.State.READY and slot.request_id not in requested_ids:
                cleared_request_ids.append(slot.request_id)
                slot.clear()
        if len(cleared_request_ids) > 0:
            logger.info(f"Clearing slot for requests {cleared_request_ids} as they are not requested.")
        active_slots = [slot for slot in self.slots if slot.state == slot.State.READY]
        if len(active_slots) < len(request_ids):
            logger.error("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")

        batch_size = len(self.slots)
        position_ids = torch.zeros(
            [batch_size, 1],
            dtype=torch.int64,
        )
        # Token used to fill the rows of empty slots: fall back to the (first) EOS token
        # when the tokenizer defines no pad token.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            eos_token_id = self.tokenizer.eos_token_id
            pad_token_id = eos_token_id[0] if isinstance(eos_token_id, list) else eos_token_id
        # Create blank inputs covering all slots (even empty ones)
        input_ids = torch.full(
            [batch_size, 1],
            fill_value=pad_token_id,
            dtype=torch.int64,
        )
        attention_mask = None
        # Slots taking part in this step, keyed by request id. Generations are produced per
        # request (not per slot), so they must be matched back to slots by request id rather
        # than by position. The map is built before generating in case `_generate_token`
        # changes slot state (NOTE(review): presumably it clears finished slots - confirm).
        decoded_slots = {}
        for i, slot in enumerate(self.slots):
            if slot.state == Slot.State.EMPTY:
                continue
            decoded_slots[slot.request_id] = slot
            # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
            input_ids.index_put_([torch.tensor([i])], slot.next_token)
            if not self._supports_static_cache:
                # When using dynamic cache, the whole attention mask needs to be passed over to the model at each iteration.
                if attention_mask is None:
                    # Create default mask covering all slots (even empty ones)
                    attention_mask = torch.zeros(
                        [batch_size, slot.attention_mask.size(-1)],
                        dtype=torch.int64,
                    )
                attention_mask.index_put_([torch.tensor([i])], slot.attention_mask)
            position_ids.index_put_([torch.tensor([i])], torch.tensor(slot.position_id))
        if not decoded_slots:
            # Nothing to decode. Without this check, the dynamic-cache path would fail on a
            # None attention mask, and the static path would run the model on padding only.
            raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")

        extra_args = {}
        if self._supports_static_cache:
            # With a static cache, the cache position is the largest slot position id.
            extra_args["cache_position"] = position_ids.max().unsqueeze(0).to(self.model.device)
        else:
            extra_args["attention_mask"] = attention_mask.to(self.model.device)
        extra_args["past_key_values"] = self.past_key_values
        generations, next_batch = self._generate_token(
            next_batch_id,
            input_ids.to(self.model.device),
            self.model_one_token,
            position_ids=position_ids.to(self.model.device),
            **extra_args,
        )
        # Advance each decoded slot by the number of tokens generated for its request.
        for gen in generations:
            slot = decoded_slots.get(gen.request_id)
            if slot is not None:
                slot.position_id += len(gen.tokens.ids)

        return generations, next_batch