in prepro.py [0:0]
def convert_qa_feature(tokenizer, question, passage, max_length,
                       max_n_answers, compute_span, similar_answers, args):
    # passage is a single positive or negative passage: passage[0] is the title and passage[1] is the list of passage word tokens.
    # passage[2] is a list of dicts of the form {"text": "answer text", "answer_start": start_pos, "word_start": word_start_pos, "word_end": word_end_pos},
    # with one dict for each occurrence of the answer.
    # "answer_start" is the character index at which the answer starts;
    # "word_start" and "word_end" are the word indices of the first and last answer words.
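    # Illustrative example of the expected structure (made-up values):
    #   passage = ("Alan Turing",
    #              ["Alan", "Turing", "was", "born", "in", "London", "."],
    #              [{"text": "London", "answer_start": 24, "word_start": 5, "word_end": 5}])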
    question_tokens = tokenizer.tokenize(question)
    if args.pad_question:
        question_tokens = question_tokens + ['[PAD]'] * max(0, args.max_question_length - len(question_tokens))
    title_tokens = tokenizer.tokenize(passage[0])
    passage_tokens = []
    tok_to_orig_index = []
    orig_to_tok_index = []
    # Here we re-tokenize each passage word with the wordpiece tokenizer (BertTokenizer). The words themselves
    # were originally produced by BasicTokenizer, so the text is tokenized twice: this lets us map subtokens
    # back to the original words and compute the answer span accurately.
    for (i, token) in enumerate(passage[1]):
        orig_to_tok_index.append(len(passage_tokens))
        sub_tokens = tokenizer.tokenize(token)
        for sub_token in sub_tokens:
            tok_to_orig_index.append(i)
            passage_tokens.append(sub_token)
    # orig_to_tok_index has len(original tokens) entries: position i holds the index of the first subtoken
    # produced from original token i.
    # tok_to_orig_index has len(subtokens) entries: position j holds the index of the original token that
    # subtoken j came from.
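    # Worked example (hypothetical wordpiece splits, for illustration only):
    #   passage[1]        = ["unaffable", "cat"]  ->  passage_tokens = ["una", "##ffa", "##ble", "cat"]
    #   orig_to_tok_index = [0, 3]        (word 0 starts at subtoken 0, word 1 starts at subtoken 3)
    #   tok_to_orig_index = [0, 0, 0, 1]  (subtokens 0-2 come from word 0, subtoken 3 from word 1)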
    if similar_answers:
        for ans in similar_answers:
            passage_tokens += ['[SEP]'] + tokenizer.tokenize(ans)
    tokens = []
    tokens.append("[CLS]")
    for token in question_tokens:
        tokens.append(token)
    if len(question_tokens) > 0:
        tokens.append("[SEP]")
    for token in title_tokens:
        tokens.append(token)
    if len(title_tokens) > 0:
        tokens.append("[SEP]")
    for token in passage_tokens:
        tokens.append(token)
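    # At this point the sequence looks like (assuming a non-empty question and title):
    #   [CLS] question tokens [SEP] title tokens [SEP] passage subtokens ([SEP] similar-answer tokens)*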
    truncated = 0
    if len(tokens) > max_length:
        tokens = tokens[:max_length]
        truncated = 1
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)
    while len(input_ids) < max_length:
        input_ids.append(0)
        input_mask.append(0)
    assert len(input_ids) == max_length
    assert len(input_mask) == max_length
    # input_mask marks which positions are real input tokens (1) and which are the padding zeros appended above (0).
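    # For example, with max_length = 8 and 5 real tokens (illustrative values):
    #   input_ids  = [id1, id2, id3, id4, id5, 0, 0, 0]
    #   input_mask = [1,   1,   1,   1,   1,   0, 0, 0]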
    offset = len(question_tokens) + len(title_tokens) + 3
    token_to_orig_map = {token + offset: orig for token, orig in enumerate(tok_to_orig_index)}
    # offset accounts for the question, the title, the [CLS] token, and the two [SEP] tokens;
    # token_to_orig_map turns tok_to_orig_index into a dict keyed by position in the full sequence, for easy lookup.
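    # Illustrative offset arithmetic: with 4 question subtokens and 2 title subtokens,
    # offset = 4 + 2 + 3 = 9, so passage subtoken j sits at sequence position j + 9 and
    # token_to_orig_map = {9: tok_to_orig_index[0], 10: tok_to_orig_index[1], ...}.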
    if compute_span:  # used during training
        start_positions, end_positions = [], []
        for answer in passage[2]:  # one dict per answer occurrence; see the format described at the top of the function
            tok_start_position = offset + orig_to_tok_index[answer['word_start']]  # subtoken index of the answer's first word, shifted into the full sequence
            # now deal with the end of the span
            if len(orig_to_tok_index) == answer['word_end'] + 1:
                # the answer ends on the last word of the passage, so there is no next word to look at;
                # fall back to the first subtoken of that word (this may undershoot if the word splits into several subtokens)
                tok_end_position = offset + orig_to_tok_index[answer['word_end']]
            else:
                tok_end_position = offset + orig_to_tok_index[answer['word_end'] + 1] - 1  # start of the next word minus one, i.e. the last subtoken of the answer
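            # Worked example (illustrative): if the passage subtokens are ["una", "##ffa", "##ble", "cat"]
            # (orig_to_tok_index = [0, 3]) and the answer is word 0 with offset = 9, then
            # tok_start_position = 9 + 0 = 9 and tok_end_position = 9 + 3 - 1 = 11, covering "una ##ffa ##ble".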
            if tok_end_position > max_length:  # the span falls outside the truncated sequence; skip it and check the next occurrence
                continue
            start_positions.append(tok_start_position)  # multiple start positions may be collected for this passage
            end_positions.append(tok_end_position)  # and the corresponding end positions
        if len(start_positions) > max_n_answers:  # truncate to the maximum number of answers
            start_positions = start_positions[:max_n_answers]
            end_positions = end_positions[:max_n_answers]
        answer_mask = [1 for _ in range(len(start_positions))]
        # We need max_n_answers values; answer_mask denotes which of them are valid.
        for _ in range(max_n_answers - len(start_positions)):  # pad start_positions, end_positions, and answer_mask up to max_n_answers
            start_positions.append(0)
            end_positions.append(0)
            answer_mask.append(0)
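        # For example, with max_n_answers = 3 and two valid spans found (illustrative values):
        #   start_positions = [12, 40, 0], end_positions = [14, 41, 0], answer_mask = [1, 1, 0]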
    else:
        # for evaluation
        start_positions, end_positions, answer_mask = None, None, None
    return input_ids, input_mask, \
        tokens, token_to_orig_map, \
        start_positions, end_positions, answer_mask, truncated, len(question_tokens)
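

# Illustrative usage sketch: assumes a HuggingFace BertTokenizer and an argparse-style args namespace;
# the question, passage tuple, and flag values below are made up for demonstration.
if __name__ == "__main__":
    from argparse import Namespace
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    args = Namespace(pad_question=False, max_question_length=32)
    passage = ("Alan Turing",
               ["Alan", "Turing", "was", "born", "in", "London", "."],
               [{"text": "London", "answer_start": 24, "word_start": 5, "word_end": 5}])
    (input_ids, input_mask, tokens, token_to_orig_map,
     start_positions, end_positions, answer_mask, truncated, q_len) = convert_qa_feature(
        tokenizer, "Where was Alan Turing born?", passage,
        max_length=64, max_n_answers=3, compute_span=True, similar_answers=None, args=args)
    print(tokens[start_positions[0]:end_positions[0] + 1])  # expected: the subtoken(s) for "London"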