def check_details()

in bindings/python/scripts/spm_parity_check.py [0:0]


def check_details(line, spm_ids, tok_ids, sp, tok):
    """Decide whether two differing token-id sequences are an acceptable mismatch.

    The common prefix and common suffix of ``spm_ids`` and ``tok_ids`` are
    stripped, and the remaining differing windows are handed to ``check_diff``.
    If that fails and the differing window of ``spm_ids`` is longer than five
    tokens, the function tries to split the window at a run of three tokens
    that appears in both windows and checks each part separately (recursing on
    the tail). When nothing matches, the differing tokens are printed and
    ``False`` is returned.

    Args:
        line: Forwarded unchanged to recursive calls; not otherwise used here.
        spm_ids: Sequence of token ids from the first encoder. Must support
            ``len``, slicing and ``reversed``.
            (presumably the SentencePiece output -- confirm against caller)
        tok_ids: Sequence of token ids from the second encoder, same
            requirements as ``spm_ids``.
        sp: Passed through to ``check_diff``.
        tok: Passed through to ``check_diff``; its ``decode`` method is also
            used to render the diagnostic output.

    Returns:
        ``True`` if ``check_diff`` accepts the differing window (directly or
        after subdivision), ``False`` otherwise.
    """
    # Encoding can be the same with same result AAA -> A + AA vs AA + A
    # We can check that we use at least exactly the same number of tokens.

    # Number of leading ids shared by both sequences. Counting matches (rather
    # than reading the loop index after the loop) stays correct for empty
    # input and when one sequence is a prefix of the other.
    first = 0
    for spm_id, tok_id in zip(spm_ids, tok_ids):
        if spm_id != tok_id:
            break
        first += 1

    # Number of trailing ids shared by both sequences, capped so the suffix
    # never overlaps the prefix already matched above.
    max_suffix = min(len(spm_ids), len(tok_ids)) - first
    suffix = 0
    for spm_id, tok_id in zip(reversed(spm_ids), reversed(tok_ids)):
        if suffix >= max_suffix or spm_id != tok_id:
            break
        suffix += 1

    # Each sequence gets its own end index: the sequences may differ in
    # length, so a single end index would misalign one of the two windows.
    spm_last = len(spm_ids) - suffix
    tok_last = len(tok_ids) - suffix

    spm_diff = spm_ids[first:spm_last]
    tok_diff = tok_ids[first:tok_last]

    if check_diff(spm_diff, tok_diff, sp, tok):
        return True

    if len(spm_diff) > 5:
        # We might have twice a single problem, attempt to subdivide the disjointed tokens into smaller problems
        spm_counts = Counter(spm_diff)
        tok_counts = Counter(tok_diff)

        # Ids occurring the same number of times in both differing windows;
        # only runs made entirely of such ids are used as split points.
        removable_tokens = {
            token_id for token_id, count in spm_counts.items() if tok_counts.get(token_id, 0) == count
        }
        min_width = 3
        for spm_start in range(len(spm_diff) - min_width):
            window = spm_diff[spm_start : spm_start + min_width]
            if not all(token_id in removable_tokens for token_id in window):
                continue
            # Offsets in the tok window where the same run of ids appears.
            possible_matches = [
                tok_start
                for tok_start in range(len(tok_diff) - min_width)
                if tok_diff[tok_start : tok_start + min_width] == window
            ]
            for tok_start in possible_matches:
                # The part before the shared run must be acceptable on its
                # own, and the remainder (starting at the run) recursively.
                if check_diff(spm_diff[:spm_start], tok_diff[:tok_start], sp, tok) and check_details(
                    line,
                    spm_diff[spm_start:],
                    tok_diff[tok_start:],
                    sp,
                    tok,
                ):
                    return True

    print(f"Spm: {[tok.decode([spm_id]) for spm_id in spm_diff]}")
    try:
        print(f"Tok: {[tok.decode([tok_id]) for tok_id in tok_diff]}")
    except Exception:
        # Best-effort diagnostic: if an id cannot be decoded, show the raw ids
        # instead of silently dropping the line.
        print(f"Tok (raw ids): {list(tok_diff)}")

    ok_start = tok.decode(spm_ids[:first])
    ok_end = tok.decode(spm_ids[spm_last:])
    wrong = tok.decode(spm_diff)
    print()
    if has_color:
        print(f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}")
    else:
        print(wrong)
    return False