def get_strided_contexts_and_ans()

in jax-projects/big_bird/prepare_natural_questions.py


def get_strided_contexts_and_ans(example, tokenizer, doc_stride=2048, max_length=4096, assertion=True):
    # consecutive document windows overlap by doc_stride - q_len tokens
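    # e.g. with the defaults max_length=4096, doc_stride=2048 and, say, q_len=16: window starts
    # advance by max_length - doc_stride = 2048 tokens and each window holds max_length - q_len = 4080
    # context tokens, so consecutive windows share 4080 - 2048 = 2032 = doc_stride - q_len tokens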

    out = get_context_and_ans(example, assertion=assertion)
    answer = out["answer"]

    # samples without an annotated answer are flagged with -1 here and filtered out later
    if answer["start_token"] == -1:
        return {
            "example_id": example["id"],
            "input_ids": [[-1]],
            "labels": {
                "start_token": [-1],
                "end_token": [-1],
                "category": ["null"],
            },
        }

    input_ids = tokenizer(example["question"]["text"], out["context"]).input_ids
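    # q_len covers [CLS] + question tokens + the first [SEP]; input_ids[:q_len] is the question
    # prefix that gets prepended to every document window below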
    q_len = input_ids.index(tokenizer.sep_token_id) + 1

    # yes/no answers: label every window with the category only; no answer span is supervised (start/end = -100)
    if answer["category"][0] in ["yes", "no"]:  # category is list with one element
        inputs = []
        category = []
        q_indices = input_ids[:q_len]
        doc_start_indices = range(q_len, len(input_ids), max_length - doc_stride)
        for i in doc_start_indices:
            end_index = i + max_length - q_len
            slice = input_ids[i:end_index]
            inputs.append(q_indices + slice)
            category.append(answer["category"][0])
            if slice[-1] == tokenizer.sep_token_id:
                break

        return {
            "example_id": example["id"],
            "input_ids": inputs,
            "labels": {
                "start_token": [-100] * len(category),
                "end_token": [-100] * len(category),
                "category": category,
            },
        }

    splitted_context = out["context"].split()
    complete_end_token = splitted_context[answer["end_token"]]
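    # answer["start_token"] / answer["end_token"] are word-level indices into the whitespace-split
    # context; the two calls below convert them to word-piece indices by counting the sub-tokens
    # that come before the answer words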
    answer["start_token"] = len(
        tokenizer(
            " ".join(splitted_context[: answer["start_token"]]),
            add_special_tokens=False,
        ).input_ids
    )
    answer["end_token"] = len(
        tokenizer(" ".join(splitted_context[: answer["end_token"]]), add_special_tokens=False).input_ids
    )

    answer["start_token"] += q_len
    answer["end_token"] += q_len

    # fix the end token: if the last answer word splits into multiple word pieces, move end_token to its final piece
    num_sub_tokens = len(tokenizer(complete_end_token, add_special_tokens=False).input_ids)
    if num_sub_tokens > 1:
        answer["end_token"] += num_sub_tokens - 1

    old = input_ids[answer["start_token"] : answer["end_token"] + 1]  # both boundaries are inclusive
    start_token = answer["start_token"]
    end_token = answer["end_token"]

    if assertion:
        """This won't match exactly because of extra gaps => visaully inspect everything"""
        new = tokenizer.decode(old)
        if answer["span"] != new:
            print("ISSUE IN TOKENIZATION")
            print("OLD:", answer["span"])
            print("NEW:", new, end="\n\n")

    if len(input_ids) <= max_length:
        return {
            "example_id": example["id"],
            "input_ids": [input_ids],
            "labels": {
                "start_token": [answer["start_token"]],
                "end_token": [answer["end_token"]],
                "category": answer["category"],
            },
        }

    q_indices = input_ids[:q_len]
    doc_start_indices = range(q_len, len(input_ids), max_length - doc_stride)

    inputs = []
    answers_start_token = []
    answers_end_token = []
    answers_category = []  # null, yes, no, long, short
    for i in doc_start_indices:
        end_index = i + max_length - q_len
        slice = input_ids[i:end_index]
        inputs.append(q_indices + slice)
        assert len(inputs[-1]) <= max_length, "Issue in truncating length"

        # compare against the absolute answer positions and keep start_token / end_token
        # untouched, so every window is checked against the original span (not a value
        # mutated by an earlier iteration)
        if start_token >= i and end_token <= end_index - 1:
            window_start = start_token - i + q_len
            window_end = end_token - i + q_len
            answers_category.append(answer["category"][0])  # ["short"] -> "short"
        else:
            window_start = -100
            window_end = -100
            answers_category.append("null")
        new = inputs[-1][window_start : window_end + 1]

        answers_start_token.append(window_start)
        answers_end_token.append(window_end)
        if assertion:
            """checking if above code is working as expected for all the samples"""
            if new != old and new != [tokenizer.cls_token_id]:
                print("ISSUE in strided for ID:", example["id"])
                print("New:", tokenizer.decode(new))
                print("Old:", tokenizer.decode(old), end="\n\n")
        if slice[-1] == tokenizer.sep_token_id:
            break

    return {
        "example_id": example["id"],
        "input_ids": inputs,
        "labels": {
            "start_token": answers_start_token,
            "end_token": answers_end_token,
            "category": answers_category,
        },
    }
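
A minimal usage sketch of how the returned windows might be consumed. `example` is assumed to be one raw record in whatever format `get_context_and_ans` expects; the checkpoint name and variable names below are illustrative assumptions, not taken from the script.

from transformers import BigBirdTokenizer

tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")  # assumed checkpoint
features = get_strided_contexts_and_ans(example, tokenizer, doc_stride=2048, max_length=4096, assertion=False)

if features["input_ids"] == [[-1]]:
    pass  # unanswerable sample, flagged here and filtered out later
else:
    for ids, start, end, category in zip(
        features["input_ids"],
        features["labels"]["start_token"],
        features["labels"]["end_token"],
        features["labels"]["category"],
    ):
        if start != -100:  # only windows that actually contain the gold span carry token labels
            print(category, tokenizer.decode(ids[start : end + 1]))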