def main()

in src/biencoder_predict_qa.py


def main():
    parser = argparse.ArgumentParser(description='Generate predictions on SQuAD or MRQA.')
    parser.add_argument('dataset_name', choices=['squad', 'mrqa'])
    parser.add_argument('output_file', help='Where to write predictions JSON')
    parser.add_argument('load_dir', help='Directory containing trained model checkpoint')
    parser.add_argument('--batch-size', '-b', default=32, type=int, help='Maximum batch size')
    parser.add_argument('--max-tokens', '-m', default=8800, type=int, help='Maximum tokens per batch')
    parser.add_argument('--context-max-len', type=int, default=456)
    parser.add_argument('--context-stride', type=int, default=328)  # Stride between consecutive sliding windows; overlap = 456 - 328 = 128 tokens
    parser.add_argument('--question-max-len', type=int, default=456)  # Longer for MRQA
    args = parser.parse_args()

    roberta = RobertaModel.from_pretrained(args.load_dir, checkpoint_file='model.pt')
    roberta.to('cuda')
    roberta.eval()
    heads = {}
    for name in ['start', 'end']:
        head = MockClassificationHead.load_from_file(
                os.path.join(args.load_dir, f'model_qa_head_{name}.pt'),
                do_token_level=True)
        head.to('cuda')
        heads[name] = head
    print('Finished loading model.')

    print(f'Reading dataset {args.dataset_name}')
    if args.dataset_name == 'squad':
        examples = read_squad(SQUAD_DEV_FILE, roberta, args)
        run_prediction(roberta, heads, examples, args)
        subprocess.check_call(['python', 'data/squad/evaluate-v2.0.py', SQUAD_DEV_FILE, 
                               args.output_file])
    else:  # MRQA
        mrqa_filenames, examples = read_mrqa_all(roberta, args)
        run_prediction(roberta, heads, examples, args)
        for filename in mrqa_filenames:
            print(f'Evaluating on {filename}')
            subprocess.check_call(['python', 'data/mrqa/mrqa_official_eval.py', filename, args.output_file])
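
The --context-max-len/--context-stride defaults imply a 128-token overlap between consecutive context windows (456 - 328). A minimal standalone sketch of that windowing arithmetic, not the repository's actual chunking code (which presumably lives in the read_* helpers):

def window_offsets(num_tokens: int, max_len: int = 456, stride: int = 328):
    """Yield start offsets of sliding windows over a token sequence.

    Consecutive windows start `stride` tokens apart, so adjacent windows
    share max_len - stride = 128 tokens of overlapping context.
    """
    start = 0
    while True:
        yield start
        if start + max_len >= num_tokens:
            break
        start += stride

# Example: a 1000-token context yields windows starting at 0, 328, 656,
# each overlapping its predecessor by 456 - 328 = 128 tokens.
print(list(window_offsets(1000)))  # [0, 328, 656]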
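
A hedged invocation sketch: the positional arguments and flags mirror the parser above, while checkpoint_dir/ and predictions.json are illustrative placeholders, and the relative data/ evaluation paths assume the script is launched from the repository root.

python src/biencoder_predict_qa.py squad predictions.json checkpoint_dir/ --batch-size 16 --max-tokens 8800

Here checkpoint_dir/ would need to hold model.pt plus model_qa_head_start.pt and model_qa_head_end.pt, matching the files loaded in main().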