in text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py [0:0]
def _token_encode(self, text: str, max_length: int) -> Tuple[jnp.ndarray, int]:
    """Tokenize the input text and return the corresponding input_ids and true_length.

    Args:
        text (`str`):
            The input text to tokenize.
        max_length (`int`):
            The maximum length of the input_ids (typically coming from the request).

    Returns:
        A tuple of the padded token ids (`jnp.ndarray`) and the true (unpadded) token count (`int`).
    """
    if max_length == 0:
        max_length = self.model.config.sequence_length
    # Subtract one from max_length because a BOS token is prepended when padding
    max_length -= 1
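    # E.g. with sequence_length=1024, the tokenizer may emit at most 1023
    # tokens, so BOS + 1023 tokens still fit within the 1024 positions.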
    input_ids = self.tokenizer.encode(
        text,
        return_tensors="np",
        truncation=True,
        max_length=max_length,
        add_special_tokens=False,
    )
    # Pad the prefill to the nearest bucketed length; max_prefill_length must be a power of 2
    max_prefill_length = take_nearest_length(DEFAULT_PREFILL_BUCKETS, self.model.config.sequence_length)
    # pad_tokens prepends the BOS token (is_bos=True) and pads up to the bucketed length
    tokens, true_length = pad_tokens(
        input_ids[0],
        self.tokenizer.bos_token_id,
        self.tokenizer.pad_token_id,
        is_bos=True,
        max_prefill_length=max_prefill_length,
        jax_padding=True,
    )
    return tokens, true_length
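
# A minimal sketch of the bucketing helper used above, assuming the usual
# JetStream semantics; the real `take_nearest_length` lives in
# jetstream.engine.token_utils, and the `_sketch` name below is hypothetical,
# shown here only to illustrate the behavior relied on by `_token_encode`.
import bisect
from typing import List

def take_nearest_length_sketch(buckets: List[int], length: int) -> int:
    """Return the smallest bucket >= length, falling back to the largest bucket."""
    pos = bisect.bisect_left(buckets, length)  # buckets assumed sorted ascending
    return buckets[pos] if pos < len(buckets) else buckets[-1]

# Example: with buckets [128, 256, 512, 1024], a 200-token prompt is padded to
# the 256 bucket, keeping the set of compiled prefill shapes small and fixed.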