def main()

in run_ranking.py
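Entry point: parses command-line arguments, loads the pretrained model and tokenizer, and runs training and/or test evaluation of the BERT reranker via PyTorch Lightning.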


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--gpu', nargs='*')
    parser.add_argument('--pretrained_model', default='bert-base-uncased', type=str)
    parser.add_argument('--overwrite', default=None, nargs='*')
    parser.add_argument('--do_train', action='store_true')
    parser.add_argument('--do_test', action='store_true')
    parser.add_argument('--data_dir', default='/home/ec2-user/efs/ott-qa/', type=str)
    parser.add_argument('--num_cand', default=64, type=int)
    parser.add_argument('--question_type', default='ott-qa', choices=['opensquad', 'wikisql_denotation', 'NQ-open', 'ott-qa'])

    parser.add_argument('--task_type', default='AutoModelForSequenceClassification', type=str)
    parser.add_argument('--checkpoint_dir', default='/home/ec2-user/efs/ck/ott-qa/', type=str)
    parser.add_argument('--cache_dir', default='/home/ec2-user/efs/cache/', type=str)
    parser.add_argument('--tensorboard_dir', default='/home/ec2-user/efs/wandb/', type=str)
    parser.add_argument('--load_model_checkpoint', default=None, type=str, help="Checkpoint file from which to continue training.")
    parser.add_argument('--weight_decay', default=0.01, type=float)
    parser.add_argument('--learning_rate', default=5e-5, type=float)
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--test_batch_size', default=1024, type=int)
    parser.add_argument('--adam_epsilon', default=1e-8, type=float)
    parser.add_argument('--lr_schedule', default='linear', type=str, choices=['linear', 'cosine', 'cosine_hard', 'constant'])
    parser.add_argument('--warmup_steps', default=0.1, type=float, help="fraction of total steps if < 1; otherwise an absolute number of steps")
    parser.add_argument('--gradient_accumulation_steps', default=2, type=int)
    parser.add_argument('--num_train_epochs', default=3, type=int)
    parser.add_argument('--num_train_steps', default=10000, type=int)
    parser.add_argument('--max_length', default=150, type=int)

    # add all the available options to the trainer
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()
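    # args now also carries every pl.Trainer flag (e.g. precision, used when building the Trainer below)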

    args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.question_type)
    logger.info(f"Checkpoint directory: {args.checkpoint_dir}")

    if args.gpu is None:
        logger.info("not using GPU")
        args.gpu = 0
    else:
        try:
            args.gpu = [int(x) for x in args.gpu]
            logger.info(f"using gpu {args.gpu}")
        except ValueError:
            raise ValueError("--gpu only accepts numerical device ids")

    # load the pretrained config, tokenizer, and model registered for this task type
    logger.info("loading pretrained model and tokenizer")
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.task_type]
    config = config_class.from_pretrained(args.pretrained_model, cache_dir=args.cache_dir)
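    # a single output logit serves as the relevance score for each question-candidate pair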
    config.num_labels = 1
    tokenizer = tokenizer_class.from_pretrained(args.pretrained_model, use_fast=True, cache_dir=args.cache_dir)
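    # from_tf=False loads native PyTorch weights rather than converting a TensorFlow checkpoint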
    model = model_class.from_pretrained(
                args.pretrained_model,
                from_tf=False,
                config=config,
                cache_dir=args.cache_dir)

    # register [title] as an additional special token so table titles are marked in the input
    additional_special_tokens_dict = {'additional_special_tokens': ['[title]']}
    tokenizer.add_special_tokens(additional_special_tokens_dict)
    # grow the embedding matrix so the new token gets a trainable vector
    model.resize_token_embeddings(len(tokenizer))

    if args.overwrite is None:
        args.overwrite = []

    # checkpoint
    checkpoint_dir = os.path.join(args.checkpoint_dir, f'{args.pretrained_model}/')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
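    # checkpoint filenames embed epoch, val_loss, and avg_val_performance; the
    # best-checkpoint search in the --do_test branch parses metrics out of this exact template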
    checkpoint_callback = ModelCheckpoint(monitor='avg_val_performance', filepath=checkpoint_dir+'{epoch}-{val_loss:.4f}-{avg_val_performance:.4f}', mode='max')

    # training and testing
    if args.do_train:
        # initialize dataloaders; listing a split in --overwrite forces its cached features to be rebuilt
        train_dataloader = generate_dataloader(
            args.data_dir,
            tokenizer,
            args.max_length,
            'train',
            args.num_cand,
            'train' in args.overwrite,
            args.batch_size,
            args.question_type,
        )
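        # the 'dev' split serves as the validation set during training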
        val_dataloader = generate_dataloader(
            args.data_dir,
            tokenizer,
            args.max_length,
            'dev',
            args.num_cand,
            'dev' in args.overwrite,
            args.batch_size,
            args.question_type,
        )
        # (the test dataloader is built in the --do_test branch below)

        # a non-positive --num_train_steps means: derive the step budget from the data
        if args.num_train_steps <= 0:
            args.num_train_steps = len(train_dataloader) * args.num_train_epochs

        bert_ranker = Reranker(model, tokenizer, args)
        if args.load_model_checkpoint is not None:
            logger.info(f"Loading the checkpoint {args.load_model_checkpoint} and continuing training")
            # map_location keeps tensors on CPU so the checkpoint loads regardless of which devices it was saved on
            model_checkpoint = torch.load(args.load_model_checkpoint, map_location=lambda storage, loc: storage)
            model_dict = model_checkpoint['state_dict']
            bert_ranker.load_state_dict(model_dict)

        # despite the name, this logs to Weights & Biases rather than TensorBoard
        tb_logger = loggers.WandbLogger(save_dir=args.tensorboard_dir, project='hybridQA-ott-qa')
        trainer = pl.Trainer(logger=tb_logger,
                             checkpoint_callback=checkpoint_callback,
                             gpus=args.gpu,
                             distributed_backend='dp',
                             val_check_interval=0.25, # check every certain % of an epoch
                             # min_epochs=args.num_train_epochs,
                             max_epochs=args.num_train_epochs,
                             max_steps=args.num_train_steps,
                             accumulate_grad_batches=args.gradient_accumulation_steps,
                             gradient_clip_val=1.0,
                             precision=args.precision)
        # train
        trainer.fit(bert_ranker, train_dataloader, val_dataloader)
        # trainer.test(bert_ranker)

    if args.do_test:
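        # release cached GPU memory left over from training before evaluation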
        torch.cuda.empty_cache()

        # initialize the test dataloader
        test_dataloader = generate_dataloader(
            args.data_dir,
            tokenizer,
            args.max_length,
            'test',
            args.num_cand,
            'test' in args.overwrite,
            args.test_batch_size,
            args.question_type,
        )
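        # prefer an explicitly supplied checkpoint; otherwise pick the best one found on disk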

        if args.load_model_checkpoint:
            best_checkpoint_file = args.load_model_checkpoint
        else:
            # pick the checkpoint with the highest avg_val_performance by parsing
            # the metrics embedded in the filenames by ModelCheckpoint above
            best_checkpoint_file = None
            best_val_performance = -100.
            best_val_loss = 100.
            for checkpoint_file in glob.glob(checkpoint_dir + "*avg_val_performance*.ckpt"):
                val_performance = float(checkpoint_file.split('=')[-1].replace('.ckpt', ''))
                val_loss = float(checkpoint_file.split('=')[-2].split('-')[0])
                if val_performance > best_val_performance:
                    best_val_performance = val_performance
                    best_val_loss = val_loss
                    best_checkpoint_file = checkpoint_file
            if best_checkpoint_file is None:
                raise FileNotFoundError(f"no checkpoint matching *avg_val_performance*.ckpt found in {checkpoint_dir}")
        logger.info(f"Loading the checkpoint: {best_checkpoint_file}")

        # load model
        bert_ranker = RerankerInference(model, tokenizer, args)
        best_checkpoint = torch.load(best_checkpoint_file, map_location=lambda storage, loc: storage)
        bert_ranker.load_state_dict(best_checkpoint['state_dict'])

        # evaluate via the Trainer's test loop; benchmark=True enables cuDNN autotuning for the fixed-size test batches
        trainer = pl.Trainer(gpus=args.gpu, distributed_backend='dp', benchmark=True)
        trainer.test(bert_ranker, test_dataloader)
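
For reference, a minimal sketch of how the script would be wired up and invoked; the __main__ guard is assumed rather than shown in the excerpt, and the flag values in the comment are illustrative only:

if __name__ == '__main__':
    # example invocation (hypothetical devices/paths):
    #   python run_ranking.py --do_train --do_test --gpu 0 1 \
    #       --pretrained_model bert-base-uncased --question_type ott-qa
    main()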