# jax-projects/big_bird/prepare_natural_questions.py
def get_strided_contexts_and_ans(example, tokenizer, doc_stride=2048, max_length=4096, assertion=True):
    # consecutive windows will overlap by doc_stride - q_len tokens
    out = get_context_and_ans(example, assertion=assertion)
    answer = out["answer"]

    # samples without an answer are flagged with -1 here and removed later
    if answer["start_token"] == -1:
        return {
            "example_id": example["id"],
            "input_ids": [[-1]],
            "labels": {
                "start_token": [-1],
                "end_token": [-1],
                "category": ["null"],
            },
        }

    input_ids = tokenizer(example["question"]["text"], out["context"]).input_ids
    # the question spans input_ids[:q_len], i.e. everything up to and including the first [SEP]
    q_len = input_ids.index(tokenizer.sep_token_id) + 1

    # yes/no answers: every window gets the same category and no start/end labels
    if answer["category"][0] in ["yes", "no"]:  # category is a list with a single element
        inputs = []
        category = []
        q_indices = input_ids[:q_len]
        doc_start_indices = range(q_len, len(input_ids), max_length - doc_stride)

        for i in doc_start_indices:
            end_index = i + max_length - q_len
            slice = input_ids[i:end_index]
            inputs.append(q_indices + slice)
            category.append(answer["category"][0])
            # stop once the window reaches the final [SEP] of the document
            if slice[-1] == tokenizer.sep_token_id:
                break

        return {
            "example_id": example["id"],
            "input_ids": inputs,
            "labels": {
                "start_token": [-100] * len(category),
                "end_token": [-100] * len(category),
                "category": category,
            },
        }

    # map the word-level answer indices onto sub-word (tokenizer) indices
    splitted_context = out["context"].split()
    complete_end_token = splitted_context[answer["end_token"]]
    answer["start_token"] = len(
        tokenizer(
            " ".join(splitted_context[: answer["start_token"]]),
            add_special_tokens=False,
        ).input_ids
    )
    answer["end_token"] = len(
        tokenizer(" ".join(splitted_context[: answer["end_token"]]), add_special_tokens=False).input_ids
    )

    # shift by the question length, since the context is appended after the question
    answer["start_token"] += q_len
    answer["end_token"] += q_len

    # fix the end token: if the last answer word splits into several sub-tokens,
    # the end index must point at its last sub-token
    num_sub_tokens = len(tokenizer(complete_end_token, add_special_tokens=False).input_ids)
    if num_sub_tokens > 1:
        answer["end_token"] += num_sub_tokens - 1

    old = input_ids[answer["start_token"] : answer["end_token"] + 1]  # both ends are inclusive
    start_token = answer["start_token"]
    end_token = answer["end_token"]

    if assertion:
        # this won't match exactly because of extra gaps => visually inspect everything
        new = tokenizer.decode(old)
        if answer["span"] != new:
            print("ISSUE IN TOKENIZATION")
            print("OLD:", answer["span"])
            print("NEW:", new, end="\n\n")

    # short documents fit into a single window and can be returned as-is
    if len(input_ids) <= max_length:
        return {
            "example_id": example["id"],
            "input_ids": [input_ids],
            "labels": {
                "start_token": [answer["start_token"]],
                "end_token": [answer["end_token"]],
                "category": answer["category"],
            },
        }

    q_indices = input_ids[:q_len]
    doc_start_indices = range(q_len, len(input_ids), max_length - doc_stride)

    inputs = []
    answers_start_token = []
    answers_end_token = []
    answers_category = []  # null, yes, no, long, short
    for i in doc_start_indices:
        end_index = i + max_length - q_len
        slice = input_ids[i:end_index]
        inputs.append(q_indices + slice)
        assert len(inputs[-1]) <= max_length, "Issue in truncating length"

        # use window-local variables so the global answer positions are not
        # overwritten between windows (otherwise an answer that only appears
        # in a later window would never be labelled)
        if start_token >= i and end_token <= end_index - 1:
            # the answer lies fully inside this window: shift to window-relative positions
            window_start_token = start_token - i + q_len
            window_end_token = end_token - i + q_len
            answers_category.append(answer["category"][0])  # ["short"] -> "short"
        else:
            window_start_token = -100
            window_end_token = -100
            answers_category.append("null")
        new = inputs[-1][window_start_token : window_end_token + 1]

        answers_start_token.append(window_start_token)
        answers_end_token.append(window_end_token)
        if assertion:
            # check that the windowed slicing works as expected for all samples
            if new != old and new != [tokenizer.cls_token_id]:
                print("ISSUE in strided for ID:", example["id"])
                print("New:", tokenizer.decode(new))
                print("Old:", tokenizer.decode(old), end="\n\n")
        # stop once the window reaches the final [SEP] of the document
        if slice[-1] == tokenizer.sep_token_id:
            break

    return {
        "example_id": example["id"],
        "input_ids": inputs,
        "labels": {
            "start_token": answers_start_token,
            "end_token": answers_end_token,
            "category": answers_category,
        },
    }
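

# A minimal usage sketch (not part of the original script): it assumes the
# `datasets` and `transformers` libraries are installed, that the
# "google/bigbird-roberta-base" checkpoint and the (very large)
# "natural_questions" dataset are available, and that get_context_and_ans()
# is defined earlier in this file.
if __name__ == "__main__":
    from datasets import load_dataset
    from transformers import BigBirdTokenizer

    tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
    nq_val = load_dataset("natural_questions", split="validation")

    # run a single example through the strided preprocessing and inspect the output
    features = get_strided_contexts_and_ans(
        nq_val[0], tokenizer, doc_stride=2048, max_length=4096, assertion=True
    )
    print("example:", features["example_id"])
    print("num windows:", len(features["input_ids"]))
    print("categories:", features["labels"]["category"])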