in src/biencoder_predict_qa.py [0:0]
def read_mrqa_all(roberta, args):
    """Read and BPE-encode every MRQA dev example found under MRQA_DEV_DIRS.

    Each ``*.jsonl`` file holds one JSON object per line; a line containing a
    ``'header'`` key is dataset metadata and is skipped. Paragraph contexts are
    truncated to ``MAX_PARAGRAPH_LEN`` BPE tokens and questions to
    ``args.question_max_len`` tokens, with per-file truncation stats printed.

    Args:
        roberta: model wrapper exposing ``encode(text)`` -> BPE token sequence.
        args: namespace providing ``context_max_len``, ``context_stride``,
            ``batch_size``, ``max_tokens`` and ``question_max_len``.

    Returns:
        Tuple ``(filenames, examples)`` where ``filenames`` lists the processed
        .jsonl paths and ``examples`` is a list of
        ``(qid, paragraph_bpe, question_bpe)`` tuples.
    """
    # Longest paragraph coverable by strided windows, bounded both by the
    # batch size and by the per-batch max_tokens budget.
    MAX_PARAGRAPH_LEN = min(
        args.context_max_len + (args.batch_size - 1) * args.context_stride,
        args.context_max_len + (args.max_tokens // args.context_max_len - 1) * args.context_stride)
    print(f'Max paragraph len set to {MAX_PARAGRAPH_LEN}')
    examples = []
    filenames = []
    for dirname in MRQA_DEV_DIRS:
        for filename in sorted(glob.glob(os.path.join(dirname, '*.jsonl'))):
            filenames.append(filename)
            cur_paragraphs = 0
            cur_truncated_paragraphs = 0
            cur_questions = 0
            cur_truncated_questions = 0
            # Context manager guarantees the handle is closed; the original
            # bare open() in the for-statement leaked one handle per file.
            # Explicit utf-8 avoids locale-dependent decoding of JSONL.
            with open(filename, encoding='utf-8') as f:
                for raw_line in f:
                    record = json.loads(raw_line.strip())
                    if 'header' in record:
                        continue  # metadata line, not an example
                    paragraph_bpe = roberta.encode(record['context'].strip())
                    cur_paragraphs += 1
                    if len(paragraph_bpe) > MAX_PARAGRAPH_LEN:
                        paragraph_bpe = paragraph_bpe[:MAX_PARAGRAPH_LEN]
                        cur_truncated_paragraphs += 1
                    for qa in record['qas']:
                        cur_questions += 1
                        question_bpe = roberta.encode(qa['question'].strip())
                        if len(question_bpe) > args.question_max_len:
                            question_bpe = question_bpe[:args.question_max_len]
                            cur_truncated_questions += 1
                        examples.append((qa['qid'], paragraph_bpe, question_bpe))
            print(f'(unknown): Truncated {cur_truncated_paragraphs}/{cur_paragraphs} paragraphs')
            print(f'(unknown): Truncated {cur_truncated_questions}/{cur_questions} questions')
    # default=0 guards against an empty dataset, where bare max() would raise.
    print(f'Max BPE length: {max((len(x[1]) for x in examples), default=0)}')
    return filenames, examples