# is_middle_token() — excerpt from src/datatuner/lm/utils.py

def is_middle_token(tokenizer, token_str, prefix):
    """Return True if `token_str` continues the previous token rather than
    starting a new word.

    Args:
        tokenizer: a HuggingFace-style tokenizer; its class name is inspected
            to decide which word-boundary convention applies.
        token_str: decoded string of the candidate token.
        prefix: sequence of preceding token ids; only the last one is decoded.

    Returns:
        True when `token_str` is a word-internal continuation of the previous
        token; False on empty inputs, mismatched character classes, decode
        failure, or an unrecognized tokenizer type.
    """
    tokenizer_name = str(type(tokenizer))

    # No context or no token text: cannot be a middle token.
    if not prefix or not token_str:
        return False

    # decode() is the only call here that can plausibly fail on odd ids;
    # keep the original best-effort contract of returning False in that case.
    try:
        prev_token_str = tokenizer.decode(prefix[-1])
    except Exception:
        return False

    # If the previous token is empty or ends in a non-alphanumeric
    # character, the current token starts a new word.
    if not prev_token_str or not prev_token_str[-1].isalnum():
        return False

    # The boundary characters must be of the same class: letter joins
    # letter, digit joins digit.
    same_class = (prev_token_str[-1].isalpha() and token_str[0].isalpha()) or (
        prev_token_str[-1].isdigit() and token_str[0].isdigit()
    )
    if not same_class:
        return False

    if "GPT2" in tokenizer_name:
        # GPT-2 byte-level BPE marks a word start with a leading space,
        # rendered as "\u0120" ("Ġ") in raw token strings.
        return token_str[0] not in (" ", "\u0120")
    if "OpenAIGPT" in tokenizer_name:
        # OpenAI GPT BPE marks a word END with a "</w>" suffix.
        # Bug fix: the original checked prefix[-1] (a token id) instead of
        # the decoded string, which raised AttributeError and was silently
        # swallowed, so this branch always returned False.
        return not prev_token_str.endswith("</w>")

    # Unsupported tokenizer: the original raised here but the exception was
    # immediately swallowed by a bare except, yielding False — keep that
    # observable behavior explicitly.
    return False