import tokenizers
from argparse import ArgumentParser
import sentencepiece as spm
from collections import Counter
import json
import os
import datetime

# termcolor is optional: it is only used to highlight mismatching spans.
# `has_color` records whether `colored` could be imported.
try:
    from termcolor import colored
except Exception:
    has_color = False
else:
    has_color = True


def main():
    """Parse the command line and run the requested parity check.

    When no ``--model-file`` is given, a SentencePiece model is trained on the
    input file using ``--model-prefix`` and ``--vocab-size``, and the
    ``<prefix>.model`` / ``<prefix>.vocab`` files are removed again once the
    check has finished, whether it succeeded or raised.

    ``--train`` selects ``check_train``; otherwise ``check_encode`` runs.
    """
    parser = ArgumentParser("SentencePiece parity checker")
    parser.add_argument(
        "--input-file",
        "-i",
        type=str,
        required=True,
        help="Which files do you want to train from",
    )
    parser.add_argument(
        "--model-file",
        "-m",
        type=str,
        required=False,
        default=None,
        help="Use a pretrained token file",
    )
    parser.add_argument(
        "--model-prefix",
        type=str,
        default="spm_parity",
        help="Model prefix for spm_train",
    )
    parser.add_argument(
        "--vocab-size",
        "-v",
        type=int,
        default=8000,
        help="Vocab size for spm_train",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Verbosity",
    )
    parser.add_argument(
        "--train",
        action="store_true",
        help="Instead of checking the encoder part, we check the trainer part",
    )
    parser.add_argument(
        "--from-spm",
        action="store_true",
        help="Directly load the spm file with it's own normalizer",
    )

    args = parser.parse_args()

    # Only models trained here are cleaned up; a user-supplied model file is
    # never deleted.
    trained = False
    if args.model_file is None:
        spm.SentencePieceTrainer.Train(
            f"--input={args.input_file} --model_prefix={args.model_prefix}"
            f" --character_coverage=1.0"
            f" --max_sentence_length=40000"
            f" --num_threads=1"
            f" --vocab_size={args.vocab_size}"
        )
        trained = True
        args.model_file = f"{args.model_prefix}.model"

    try:
        if args.train:
            check_train(args)
        else:
            check_encode(args)
    finally:
        if trained:
            # Remove each artifact on its own so that one missing file neither
            # leaves the other behind nor replaces the exception raised by the
            # check with a FileNotFoundError.
            for suffix in ("model", "vocab"):
                try:
                    os.remove(f"{args.model_prefix}.{suffix}")
                except FileNotFoundError:
                    pass


def check_train(args):
    """Compare the `tokenizers` Unigram trainer with a SentencePiece model.

    A ``SentencePieceUnigramTokenizer`` is trained on ``args.input_file`` with
    the requested vocabulary size, then every line of that file is encoded
    with both the new tokenizer and the SentencePiece model loaded from
    ``args.model_file``. The first differing vocabulary entry (if any) and the
    total token counts are printed.

    Args:
        args: Parsed command-line namespace; ``model_file``, ``input_file``
            and ``vocab_size`` are read.

    Raises:
        AssertionError: If the trained tokenizer does not use strictly fewer
            tokens than SentencePiece over the whole input file.
    """
    sp = spm.SentencePieceProcessor()
    sp.Load(args.model_file)

    tokenizer = tokenizers.SentencePieceUnigramTokenizer()
    # Train with the same vocabulary size that was requested for spm_train so
    # the two models are comparable.
    tokenizer.train(args.input_file, vocab_size=args.vocab_size, show_progress=False)

    spm_tokens = 0
    tokenizer_tokens = 0

    # Decode explicitly (same codec as check_encode) instead of relying on the
    # platform's locale encoding.
    with open(args.input_file, "r", encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            ids = sp.EncodeAsIds(line)

            encoded = tokenizer.encode(line)

            spm_tokens += len(ids)
            tokenizer_tokens += len(encoded.ids)

    # Size both tables from the actual vocabularies: a pretrained model or a
    # trainer that stopped early need not contain exactly args.vocab_size
    # entries.
    tokenizer_vocab = tokenizer.get_vocab()
    vocab = [0] * (max(tokenizer_vocab.values(), default=-1) + 1)
    for token, index in tokenizer_vocab.items():
        vocab[index] = token

    spm_vocab = [sp.id_to_piece(i) for i in range(sp.piece_size())]

    # 0 is unk in tokenizers, 0, 1, 2 are unk bos, eos in spm by default.
    for i, (token, spm_token) in enumerate(zip(vocab[1:], spm_vocab[3:])):
        if token != spm_token:
            print(f"First different token is token {i} ({token} != {spm_token})")
            break

    print(f"Tokenizer used {tokenizer_tokens}, where spm used {spm_tokens}")
    assert tokenizer_tokens < spm_tokens, "Our trainer should be at least more efficient than the SPM one"
    print("Ok our trainer is at least more efficient than the SPM one")


def check_diff(spm_diff, tok_diff, sp, tok):
    """Tell whether two differing id spans are an acceptable disagreement.

    Args:
        spm_diff: Ids produced by SentencePiece for the disputed span.
        tok_diff: Ids produced by the tokenizer for the disputed span.
        sp: Object whose ``decode(ids)`` / ``encode(text)`` are used to
            round-trip ``spm_diff``; ``encode`` must return a list of ids.
        tok: Object whose ``decode(ids)`` is used on both spans and whose
            ``encode(text).ids`` is used to round-trip ``spm_diff``.

    Returns:
        True when one of the three tolerated patterns below applies,
        False otherwise.
    """
    # Pattern 1 - mirrored split: AAA -> AA+A vs A+AA.
    mirrored = list(reversed(tok_diff))
    if spm_diff == mirrored:
        return True

    # Pattern 2 - same number of pieces spelling the same text:
    # Barrich -> Barr + ich vs Bar + rich.
    same_length = len(spm_diff) == len(tok_diff)
    if same_length and tok.decode(spm_diff) == tok.decode(tok_diff):
        return True

    # Pattern 3 - SentencePiece does not reproduce its own ids when asked to
    # re-encode the decoded span, and what it produces instead agrees with the
    # tokenizer's re-encoding:
    # Snehagatha ->
    #       Sne, h, aga, th, a
    #       Sne, ha, gat, ha
    spm_text = sp.decode(spm_diff)
    spm_reencoded = sp.encode(spm_text)
    tok_text = tok.decode(spm_diff)
    tok_reencoded = tok.encode(tok_text).ids
    return spm_reencoded != spm_diff and spm_reencoded == tok_reencoded


def check_details(line, spm_ids, tok_ids, sp, tok):
    """Decide whether a mismatch between two encodings is tolerable.

    The common prefix and common suffix of the two id lists are stripped and
    the remaining spans are handed to ``check_diff``. If that fails and the
    SentencePiece span is longer than 5 ids, the spans are split at a run of
    3 ids that both sides share, and each half is checked on its own (the
    right half recursively).

    When nothing matches, the disputed pieces are printed (in colour if
    ``termcolor`` is available) and False is returned.

    Args:
        line: The text that was encoded (not used by the comparison itself).
        spm_ids: List of ids produced by SentencePiece.
        tok_ids: List of ids produced by the tokenizer.
        sp: SentencePiece processor, forwarded to ``check_diff``.
        tok: Tokenizer, forwarded to ``check_diff`` and used to decode ids
            for the diagnostic output.

    Returns:
        True if the difference is one of the tolerated patterns, else False.
    """
    # Encoding can be the same with same result AAA -> A + AA vs AA + A
    # We can check that we use at least exactly the same number of tokens.
    shortest = min(len(spm_ids), len(tok_ids))

    # Length of the common prefix. Counting explicitly (rather than reusing a
    # loop index) stays correct when a list is empty or when one list is a
    # prefix of the other.
    first = 0
    while first < shortest and spm_ids[first] == tok_ids[first]:
        first += 1

    # Length of the common suffix, never overlapping the prefix.
    suffix = 0
    while suffix < shortest - first and spm_ids[-1 - suffix] == tok_ids[-1 - suffix]:
        suffix += 1

    # The two lists may have different lengths, so each side needs its own
    # end index for the disputed span.
    last = len(spm_ids) - suffix
    tok_last = len(tok_ids) - suffix

    spm_diff = spm_ids[first:last]
    tok_diff = tok_ids[first:tok_last]

    if check_diff(spm_diff, tok_diff, sp, tok):
        return True

    if last - first > 5:
        # We might have twice a single problem, attempt to subdivide the disjointed tokens into smaller problems
        spms = Counter(spm_diff)
        toks = Counter(tok_diff)

        # Ids occurring equally often on both sides are candidates for a
        # shared anchor between two independent problems.
        removable_tokens = {token for token, count in spms.items() if toks.get(token, 0) == count}
        min_width = 3
        for i in range(len(spm_diff) - min_width):
            window = spm_diff[i : i + min_width]
            if all(token in removable_tokens for token in window):
                possible_matches = [
                    k for k in range(len(tok_diff) - min_width) if tok_diff[k : k + min_width] == window
                ]
                for j in possible_matches:
                    if check_diff(spm_diff[:i], tok_diff[:j], sp, tok) and check_details(
                        line,
                        spm_diff[i:],
                        tok_diff[j:],
                        sp,
                        tok,
                    ):
                        return True

    print(f"Spm: {[tok.decode([token]) for token in spm_diff]}")
    print(f"Tok: {[tok.decode([token]) for token in tok_diff]}")

    ok_start = tok.decode(spm_ids[:first])
    ok_end = tok.decode(spm_ids[last:])
    wrong = tok.decode(spm_ids[first:last])
    print()
    if has_color:
        print(f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}")
    else:
        print(wrong)
    return False


def check_encode(args):
    """Check that the tokenizer encodes every input line like SentencePiece.

    The SentencePiece model is loaded from ``args.model_file``. The tokenizer
    is either loaded from that same file (``args.from_spm``) or rebuilt from
    the model's pieces, scores and unk id. Each stripped line of
    ``args.input_file`` is encoded with both, timing each side, and counted as
    perfect (identical ids), imperfect (different ids accepted by
    ``check_details``) or wrong. A summary with accuracy and relative slowdown
    is printed at the end.

    Args:
        args: Parsed command-line namespace; ``model_file``, ``from_spm``,
            ``input_file`` and ``verbose`` are read.

    Raises:
        AssertionError: On the first line whose difference is not accepted by
            ``check_details``.
    """
    sp = spm.SentencePieceProcessor()
    sp.Load(args.model_file)

    if args.from_spm:
        tok = tokenizers.SentencePieceUnigramTokenizer.from_spm(args.model_file)
    else:
        vocab = [(sp.id_to_piece(i), sp.get_score(i)) for i in range(sp.piece_size())]
        unk_id = sp.unk_id()
        tok = tokenizers.SentencePieceUnigramTokenizer(vocab, unk_id)

    perfect = 0
    imperfect = 0
    wrong = 0
    now = datetime.datetime.now
    spm_total_time = datetime.timedelta(seconds=0)
    tok_total_time = datetime.timedelta(seconds=0)
    with open(args.input_file, "r", encoding="utf-8-sig") as f:
        for i, line in enumerate(f):
            line = line.strip()

            start = now()
            ids = sp.EncodeAsIds(line)
            spm_time = now()

            encoded = tok.encode(line)
            tok_time = now()

            spm_total_time += spm_time - start
            tok_total_time += tok_time - spm_time

            if args.verbose:
                if i % 10000 == 0:
                    print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
                    print(f"SPM: {spm_total_time} - TOK: {tok_total_time}")

            if ids != encoded.ids:
                if check_details(line, ids, encoded.ids, sp, tok):
                    imperfect += 1
                    continue
                else:
                    wrong += 1
            else:
                perfect += 1

            # Only reached with differing ids when check_details rejected the
            # difference: stop at the first real mismatch.
            assert (
                ids == encoded.ids
            ), f"line {i}: {line} : \n\n{ids}\n{encoded.ids}\n{list(zip(encoded.ids, encoded.tokens))}"

    total = perfect + imperfect + wrong
    print(f"({perfect} / {imperfect} / {wrong} ----- {total})")
    if total == 0:
        # Empty input: there is no accuracy or slowdown to report, and the
        # ratios below would divide by zero.
        print("No lines were checked: the input file is empty")
        return
    # A zero timedelta is falsy; guard against a clock too coarse to measure
    # any SentencePiece time on tiny inputs.
    slowdown = tok_total_time / spm_total_time if spm_total_time else float("inf")
    print(f"Accuracy {perfect * 100 / total:.2f} Slowdown : {slowdown:.2f}")


# Run the parity check only when this file is executed as a script.
if __name__ == "__main__":
    main()
