in backends/gaudi/server/text_generation_server/utils/tokens.py [0:0]
import os
from typing import Optional

import torch
from transformers.utils import to_py_obj


def make_tokenizer_optional(tokenizer):
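    # Dynamically subclass the wrapped tokenizer's own class, overriding __call__
    # and decode so inputs/outputs are treated as comma-separated token-ID strings
    # rather than text to be tokenized.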
    class _(type(tokenizer)):
        def __call__(
            self,
            text,
            return_tensors,
            padding,
            return_token_type_ids,
            truncation,
            max_length,
        ):
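            # The server is expected to call the tokenizer with exactly these
            # argument values; fail fast on anything else.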
            assert (
                return_tensors == "pt"
            ), "incorrect input arguments when calling TransparentTokenizer"
            assert padding in (
                "max_length",
                "longest",
            ), "incorrect input arguments when calling TransparentTokenizer"
            assert (
                not return_token_type_ids
            ), "incorrect input arguments when calling TransparentTokenizer"
            assert (
                truncation
            ), "incorrect input arguments when calling TransparentTokenizer"
            def str_token_to_int(i):
                if i == "?":
                    return tokenizer.pad_token_id
                else:
                    return int(i)

            all_tokens = [
                [str_token_to_int(i.strip()) for i in inner_text.split(",")]
                for inner_text in text
            ]
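            # With "longest" padding, every row is padded to the longest
            # sequence in this batch rather than to the configured max_length.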
            if padding == "longest":
                max_length = max(len(tokens) for tokens in all_tokens)
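            # Left-pad input_ids with the pad token; the attention mask zeroes
            # exactly the positions added here.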
            return {
                "input_ids": torch.tensor(
                    [
                        [tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens
                        for tokens in all_tokens
                    ]
                ),
                "attention_mask": torch.tensor(
                    [
                        [0] * (max_length - len(tokens)) + [1] * len(tokens)
                        for tokens in all_tokens
                    ]
                ),
            }
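
        # Inverse of __call__: render token IDs back into a comma-separated string.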
        def decode(
            self,
            token_ids,
            skip_special_tokens: bool = False,
            clean_up_tokenization_spaces: Optional[bool] = None,
            **kwargs,
        ) -> str:
            # I don't think this method is used anywhere; it should be removed
            # in a future refactor.
            return ",".join(str(i) for i in to_py_obj(token_ids))
    if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":
        tokenizer.__class__ = _
        tokenizer.is_transparent = True
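
A minimal usage sketch (not part of the file itself): it assumes a standard Transformers tokenizer and shows how, with SKIP_TOKENIZER_IN_TGI=true, the patched __call__ parses comma-separated ID strings instead of tokenizing. The "gpt2" checkpoint and the eos-as-pad choice are arbitrary illustrations, not anything this file mandates.

import os

from transformers import AutoTokenizer

# The env var must be set before make_tokenizer_optional reads it.
os.environ["SKIP_TOKENIZER_IN_TGI"] = "true"

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
tokenizer.pad_token_id = tokenizer.eos_token_id  # gpt2 ships without a pad token
make_tokenizer_optional(tokenizer)

# Rows of comma-separated token IDs; the second row is shorter and gets left-padded.
batch = tokenizer(
    ["1, 2, 3", "4, 5"],
    return_tensors="pt",
    padding="longest",
    return_token_type_ids=False,
    truncation=True,
    max_length=8,
)
print(batch["input_ids"])       # tensor([[1, 2, 3], [<pad_id>, 4, 5]])
print(batch["attention_mask"])  # tensor([[1, 1, 1], [0, 1, 1]])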