def make_tokenizer_optional()

in backends/gaudi/server/text_generation_server/utils/tokens.py [0:0]
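
Patches the given tokenizer in place so that, when the environment variable
SKIP_TOKENIZER_IN_TGI is set to "true", the tokenizer becomes "transparent":
inputs are interpreted as comma-separated token IDs ("?" standing in for the
pad token) rather than being tokenized.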


import os

import torch
from transformers.utils import to_py_obj


def make_tokenizer_optional(tokenizer):
    # Subclass the tokenizer's concrete class so the instance can be re-typed
    # below while keeping isinstance() checks against the original class valid.
    class _(type(tokenizer)):
        def __call__(
            self,
            text,
            return_tensors,
            padding,
            return_token_type_ids,
            truncation,
            max_length,
        ):
            assert (
                return_tensors == "pt"
            ), "TransparentTokenizer only supports return_tensors='pt'"
            assert padding in (
                "max_length",
                "longest",
            ), "TransparentTokenizer only supports 'max_length' or 'longest' padding"
            assert (
                not return_token_type_ids
            ), "TransparentTokenizer does not support return_token_type_ids"
            assert truncation, "TransparentTokenizer requires truncation=True"

            def str_token_to_int(i):
                # "?" marks a placeholder token; map it to the pad token.
                if i == "?":
                    return tokenizer.pad_token_id
                return int(i)

            # Parse each input string as a comma-separated list of token IDs.
            all_tokens = [
                [str_token_to_int(i.strip()) for i in inner_text.split(",")]
                for inner_text in text
            ]
            if padding == "longest":
                max_length = max(len(tokens) for tokens in all_tokens)
            # Left-pad each sequence to max_length and build a matching attention mask.
            return {
                "input_ids": torch.tensor(
                    [
                        [tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens
                        for tokens in all_tokens
                    ]
                ),
                "attention_mask": torch.tensor(
                    [
                        [0] * (max_length - len(tokens)) + [1] * len(tokens)
                        for tokens in all_tokens
                    ]
                ),
            }

        def decode(
            self,
            token_ids,
            skip_special_tokens: bool = False,
            clean_up_tokenization_spaces: bool = None,
            **kwargs,
        ) -> str:
            # This method does not appear to be used anywhere and should be
            # removed during refactoring.
            return ",".join(str(i) for i in to_py_obj(token_ids))

    # Only patch the tokenizer when the server is configured to skip tokenization.
    if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":
        tokenizer.__class__ = _
        tokenizer.is_transparent = True
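
A minimal usage sketch, assuming the environment variable is set before the
patch is applied; the model name and token IDs below are illustrative only:

import os

from transformers import AutoTokenizer

os.environ["SKIP_TOKENIZER_IN_TGI"] = "true"

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical model choice
tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token by default
make_tokenizer_optional(tokenizer)

# Inputs are comma-separated token IDs; "?" maps to the pad token.
batch = tokenizer(
    ["1, 2, 3", "?, 4, 5, 6"],
    return_tensors="pt",
    padding="longest",
    return_token_type_ids=False,
    truncation=True,
    max_length=1024,
)
print(batch["input_ids"])       # tensor([[pad, 1, 2, 3], [pad, 4, 5, 6]])
print(batch["attention_mask"])  # tensor([[0, 1, 1, 1], [1, 1, 1, 1]])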