in scripts/train.py [0:0]
def add_end_idx(answers, contexts):
    """Add an ``answer_end`` character index to every answer, in place.

    Each answer is paired positionally with its context. The expected end
    index is ``answer_start + len(text)``. If the context does not contain
    the answer text at that position, the span is retried shifted one and
    then two characters to the left, and ``answer_start`` is corrected when
    a shifted span matches.

    Args:
        answers: Sequence of mutable mappings with a ``'text'`` string and
            an integer ``'answer_start'`` character offset. Each mapping
            gains an ``'answer_end'`` key (exclusive end offset) and may
            have ``'answer_start'`` adjusted.
        contexts: Sequence of context strings, aligned with ``answers``.

    Returns:
        None. ``answers`` is modified in place.

    Note:
        If no candidate span matches, ``answer_end`` is still set, to the
        unadjusted ``answer_start + len(text)``, so later code can rely on
        the key being present. Such an entry can be detected by checking
        ``context[answer_start:answer_end] != text``.
    """
    # loop through each answer-context pair
    for answer, context in zip(answers, contexts):
        # gold_text refers to the answer we are expecting to find in context
        gold_text = answer['text']
        # we already know the (annotated) start index
        start_idx = answer['answer_start']
        # and ideally this would be the end index...
        end_idx = start_idx + len(gold_text)

        # ...however, sometimes the annotated offsets are off by a character
        # or two. Shift 0 is the annotated position; 1 and 2 are the
        # tolerated left shifts. Stop at the first (closest) match so a
        # farther match cannot overwrite a nearer one.
        for shift in (0, 1, 2):
            candidate_start = start_idx - shift
            candidate_end = end_idx - shift
            # A negative start would make the slice count from the end of
            # the context, so such candidates are never treated as matches.
            if candidate_start >= 0 and context[candidate_start:candidate_end] == gold_text:
                answer['answer_start'] = candidate_start
                answer['answer_end'] = candidate_end
                break
        else:
            # No alignment found: keep the annotated start and record the
            # unadjusted end so 'answer_end' is always populated.
            answer['answer_end'] = end_idx