def _token_encode()

in text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py [0:0]


    def _token_encode(self, text: str, max_length: int) -> Tuple[jnp.ndarray, int]:
        """Encode `text` with the tokenizer and pad the result for prefill.

        Args:
            text (`str`):
                The prompt to tokenize.
            max_length (`int`):
                Upper bound on the number of input ids (typically taken from the
                request). A value of `0` selects the model's configured
                `sequence_length` instead.

        Returns:
            `Tuple[jnp.ndarray, int]`: the padded tokens and the true length, exactly
            as produced by `pad_tokens`.
        """
        sequence_length = self.model.config.sequence_length
        # One position is held back from the tokenizer budget because BOS is going
        # to be added by the padding step (it is invoked with is_bos=True below).
        token_budget = (sequence_length if max_length == 0 else max_length) - 1
        encoded = self.tokenizer.encode(
            text,
            add_special_tokens=False,
            truncation=True,
            max_length=token_budget,
            return_tensors="np",
        )
        # The prefill limit must be a power of 2, so it is picked from the default
        # prefill buckets based on the model sequence length.
        prefill_limit = take_nearest_length(DEFAULT_PREFILL_BUCKETS, sequence_length)
        # encoded[0]: only the first row of the tokenizer output is padded.
        padded_tokens, true_length = pad_tokens(
            encoded[0],
            self.tokenizer.bos_token_id,
            self.tokenizer.pad_token_id,
            is_bos=True,
            max_prefill_length=prefill_limit,
            jax_padding=True,
        )
        return padded_tokens, true_length