def check_train(args)

in bindings/python/scripts/spm_parity_check.py
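
Trains a fresh SentencePieceUnigramTokenizer on the input file and checks it against an existing SentencePiece model: it compares the total number of tokens each produces on the corpus and reports the first vocabulary entry on which the two disagree.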


import sentencepiece as spm
import tokenizers


def check_train(args):
    # Load the reference SentencePiece model we are comparing against.
    sp = spm.SentencePieceProcessor()
    sp.Load(args.model_file)

    # Train our own Unigram tokenizer on the same corpus.
    tokenizer = tokenizers.SentencePieceUnigramTokenizer()
    tokenizer.train(args.input_file, show_progress=False)

    spm_tokens = 0
    tokenizer_tokens = 0

    # Encode every line of the corpus with both tokenizers and count the
    # tokens each one produces; fewer tokens means a more efficient encoding.
    with open(args.input_file, "r") as f:
        for line in f:
            line = line.strip()

            ids = sp.EncodeAsIds(line)
            encoded = tokenizer.encode(line)

            spm_tokens += len(ids)
            tokenizer_tokens += len(encoded.ids)

    # Rebuild both vocabularies as index -> piece lists so they can be
    # compared position by position.
    vocab = [None] * args.vocab_size
    spm_vocab = [None] * args.vocab_size

    for token, index in tokenizer.get_vocab().items():
        vocab[index] = token

    for i in range(args.vocab_size):
        spm_vocab[i] = sp.id_to_piece(i)

    # Align the vocabularies before comparing: id 0 is unk in tokenizers,
    # while ids 0, 1, 2 are unk, bos, eos in spm by default.
    for i, (token, spm_token) in enumerate(zip(vocab[1:], spm_vocab[3:])):
        if token != spm_token:
            print(f"First different token is token {i} ({token} != {spm_token})")
            break

    print(f"Tokenizer used {tokenizer_tokens}, where spm used {spm_tokens}")
    assert tokenizer_tokens < spm_tokens, "Our trainer should be at least more efficient than the SPM one"
    print("Ok our trainer is at least more efficient than the SPM one")