def main()

in run_inference.py [0:0]


def main():
    """Run test-set inference for a T5 model restored from a checkpoint.

    Parses command-line arguments, loads the OmegaConf config named by
    ``--config``, overlays the command-line overrides onto it, builds the
    tokenizer, model and test dataloader, loads the checkpoint's
    ``state_dict`` into the model and finally calls ``trainer.test``.

    Command-line arguments:
        --config: Path of the config file passed to ``OmegaConf.load``.
            Must be non-empty.
        --data_dir: Stored into ``cfg.data.data_dir``.
        --output_dir: Stored into ``cfg.data.output_dir`` and created if
            missing. When omitted, it is derived from the config file name
            (with ``.yaml`` removed) under
            ``/home/ec2-user/efs/hybrid_qa_inference/``.
        --model_checkpoint: Required. Checkpoint file read with
            ``torch.load``; its ``'state_dict'`` entry is loaded into the
            model.
        --test_batch_size: If given, overrides ``cfg.optim.test_batch_size``.
        --overwrite_cache: If set, overrides ``cfg.data.overwrite_cache``.
        --num_beams: Stored into ``cfg.data.num_beams``.
        --num_return_sequences: Stored into ``cfg.data.num_return_sequences``.

    Unrecognised command-line arguments are ignored (``parse_known_args``).

    Raises:
        SystemExit: Raised by argparse when ``--model_checkpoint`` is missing
            or ``--config`` is empty.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='')
    parser.add_argument('--data_dir', type=str, default='/home/ec2-user/efs/FID')
    parser.add_argument('--output_dir', type=str, default=None)
    parser.add_argument('--model_checkpoint', type=str, required=True)
    parser.add_argument('--test_batch_size', type=int, default=None)
    parser.add_argument('--overwrite_cache', action='store_true')
    parser.add_argument('--num_beams', type=int, default=3)
    parser.add_argument('--num_return_sequences', type=int, default=3)
    args, _ = parser.parse_known_args()

    # Fail fast with a usage message: an empty path would otherwise only blow
    # up later inside OmegaConf.load with an unrelated-looking OS error.
    if not args.config:
        parser.error('--config must point to a config file')

    cfg = OmegaConf.load(args.config)

    if args.output_dir is None:
        # Default output location is named after the config file.
        config_name = os.path.basename(args.config).replace('.yaml', '')
        args.output_dir = os.path.join(
            '/home/ec2-user/efs/hybrid_qa_inference/', config_name
        )
    print(args.output_dir)

    # Overlay command-line overrides onto the loaded config.
    cfg.data.data_dir = args.data_dir
    cfg.data.output_dir = args.output_dir
    cfg.model.model_checkpoint = args.model_checkpoint
    cfg.data.num_beams = args.num_beams
    cfg.data.num_return_sequences = args.num_return_sequences
    if args.test_batch_size is not None:
        cfg.optim.test_batch_size = args.test_batch_size
    if args.overwrite_cache:
        # NOTE(review): this stores the string 'true', not a boolean; kept
        # as-is because the consumer's expectation is not visible here.
        cfg.data.overwrite_cache = 'true'

    os.makedirs(cfg.data.output_dir, exist_ok=True)

    # set seed
    seed_everything(cfg.optim.seed)

    tokenizer = T5Tokenizer.from_pretrained(
        cfg.model.tokenizer_name if cfg.model.tokenizer_name else cfg.model.model_name,
        cache_dir=cfg.model.cache_dir,
        use_fast=cfg.model.use_fast,
    )

    model_t5 = T5(cfg, tokenizer)

    logger.info("Evaluation starts")

    test_dataloader = generate_dataloader(
        data_dir=cfg.data.data_dir,
        tokenizer=tokenizer,
        max_source_length=cfg.data.max_source_length,
        max_target_length=cfg.data.max_target_length,
        overwrite_cache=cfg.data.overwrite_cache,
        mode="test",
        batch_size=cfg.optim.test_batch_size,
        question_type=cfg.data.question_type,
        passage_type=cfg.data.passage_type,
        enable_sql_supervision=cfg.data.enable_sql_supervision,
        cand_for_each_source=cfg.data.cand_for_each_source,
    )

    torch.cuda.empty_cache()

    # Load the checkpoint onto CPU regardless of the device it was saved from.
    # NOTE(security): torch.load may unpickle arbitrary objects; only pass
    # checkpoints from a trusted source.
    best_checkpoint = torch.load(cfg.model.model_checkpoint, map_location='cpu')
    model_t5.load_state_dict(best_checkpoint['state_dict'])

    # Run the test loop through the Trainer, configured from cfg.trainer.
    # NOTE: to override the configured precision, set cfg.trainer.precision
    # here (e.g. 32) before constructing the Trainer.
    trainer = pl.Trainer(**OmegaConf.to_container(cfg.trainer, resolve=True))
    trainer.test(model_t5, test_dataloader)