in src/biencoder_predict_qa.py [0:0]
def main():
    parser = argparse.ArgumentParser(description='Generate predictions on SQuAD or MRQA.')
    parser.add_argument('dataset_name', choices=['squad', 'mrqa'])
    parser.add_argument('output_file', help='Where to write predictions JSON')
    parser.add_argument('load_dir', help='Directory containing trained model checkpoint')
    parser.add_argument('--batch-size', '-b', default=32, type=int, help='Maximum batch size')
    parser.add_argument('--max-tokens', '-m', default=8800, type=int, help='Maximum tokens per batch')
    parser.add_argument('--context-max-len', type=int, default=456)
    # Stride = distance between consecutive sliding windows, so adjacent windows
    # overlap by 456 - 328 = 128 tokens (see the sliding-window sketch after main()).
    parser.add_argument('--context-stride', type=int, default=328)
    parser.add_argument('--question-max-len', type=int, default=456)  # Longer for MRQA
    args = parser.parse_args()

    # Load the trained RoBERTa encoder and the token-level start/end QA heads.
    roberta = RobertaModel.from_pretrained(args.load_dir, checkpoint_file='model.pt')
    roberta.to('cuda')
    roberta.eval()
    heads = {}
    for name in ['start', 'end']:
        head = MockClassificationHead.load_from_file(
            os.path.join(args.load_dir, f'model_qa_head_{name}.pt'),
            do_token_level=True)
        head.to('cuda')
        heads[name] = head
    print('Finished loading model.')

    print(f'Reading dataset {args.dataset_name}')
    if args.dataset_name == 'squad':
        examples = read_squad(SQUAD_DEV_FILE, roberta, args)
        run_prediction(roberta, heads, examples, args)
        subprocess.check_call(['python', 'data/squad/evaluate-v2.0.py', SQUAD_DEV_FILE,
                               args.output_file])
    else:  # MRQA
        mrqa_filenames, examples = read_mrqa_all(roberta, args)
        run_prediction(roberta, heads, examples, args)
        for filename in mrqa_filenames:
            print(f'Evaluating on {filename}')
            subprocess.check_call(['python', 'data/mrqa/mrqa_official_eval.py',
                                   filename, args.output_file])
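

# ---------------------------------------------------------------------------
# Illustrative sketch only: how the --context-max-len / --context-stride
# settings above are typically turned into overlapping windows over a long
# context. The actual chunking presumably lives in the read_squad /
# read_mrqa_all helpers (not shown in this file), so the function below is an
# assumption about the general technique, not this project's implementation.
def _sliding_windows_sketch(token_ids, max_len=456, stride=328):
    """Yield (start_offset, window) pairs; consecutive windows overlap by max_len - stride tokens."""
    start = 0
    while True:
        yield start, token_ids[start:start + max_len]
        if start + max_len >= len(token_ids):
            break
        start += stride

# Example: a 1000-token context with max_len=456 and stride=328 yields windows
# starting at offsets 0, 328, and 656, each sharing 128 tokens with its neighbor.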
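

# ---------------------------------------------------------------------------
# Illustrative sketch only: what run_prediction is assumed to do with the
# 'start' and 'end' heads loaded in main(). The actual head class and its
# forward signature are defined elsewhere, so ToyTokenHead and
# _pick_best_span_sketch are hypothetical stand-ins for the general technique:
# each head emits one logit per token, and the predicted answer is the span
# maximizing start_logit + end_logit.
import torch

class ToyTokenHead(torch.nn.Module):
    """Hypothetical stand-in for a token-level classification head (one logit per token)."""
    def __init__(self, hidden_dim):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_dim, 1)

    def forward(self, token_features):                 # (batch, seq_len, hidden_dim)
        return self.proj(token_features).squeeze(-1)   # (batch, seq_len)

def _pick_best_span_sketch(start_logits, end_logits, max_answer_len=30):
    """Return the (start, end) pair with the highest start_logit + end_logit, end >= start."""
    best_span, best_score = (0, 0), float('-inf')
    seq_len = start_logits.size(0)
    for s in range(seq_len):
        for e in range(s, min(s + max_answer_len, seq_len)):
            score = (start_logits[s] + end_logits[e]).item()
            if score > best_score:
                best_span, best_score = (s, e), score
    return best_span, best_score

# Usage with random features standing in for roberta.extract_features(tokens):
#   feats = torch.randn(1, 20, 16)
#   toy_heads = {name: ToyTokenHead(16) for name in ('start', 'end')}
#   span, score = _pick_best_span_sketch(toy_heads['start'](feats)[0], toy_heads['end'](feats)[0])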