def main()

in src/run_fusion_in_decoder.py
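
SageMaker entry point for the fusion-in-decoder model: it loads an OmegaConf YAML config, redirects data and output paths to the SageMaker environment, builds a T5-based LightningModule, and runs PyTorch Lightning training and/or evaluation.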


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='')
    args, _ = parser.parse_known_args()
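    # Unknown CLI arguments (e.g. hyperparameters SageMaker appends when it
    # launches the entry point) are collected and discarded by parse_known_args().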

    cfg = OmegaConf.load(f'/opt/ml/code/model_arcifacts/fusion_config{args.config}.yaml')

    # SM_CHANNEL_TRAIN and SM_MODEL_DIR are standard SageMaker environment
    # variables: the local path of the training data channel, and the directory
    # whose contents are uploaded as the model artifact when the job finishes.
    cfg.data.data_dir = os.environ['SM_CHANNEL_TRAIN']
    cfg.data.output_dir = os.path.join(os.environ['SM_MODEL_DIR'], 'output')
    os.makedirs(cfg.data.output_dir, exist_ok=True)
    cfg.model.checkpoint_dir = os.path.join(os.environ['SM_MODEL_DIR'], 'ckpt')
    os.makedirs(cfg.model.checkpoint_dir, exist_ok=True)

    # set seed (seed_everything seeds the Python, NumPy and torch RNGs for reproducibility)
    seed_everything(cfg.optim.seed)

    # checkpoint
    checkpoint_dir = os.path.join(cfg.model.checkpoint_dir, cfg.model.model_name)
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        monitor='avg_val_loss',
        filepath=os.path.join(checkpoint_dir, '{epoch}-{val_loss:.4f}'),
        mode='min',
        save_last=False,
        save_top_k=2,
    )
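    # With this template the saved files look like 'epoch=3-val_loss=0.1234.ckpt';
    # the evaluation branch below parses val_loss out of the filename to pick
    # the best checkpoint.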

    tokenizer = T5Tokenizer.from_pretrained(
        cfg.model.tokenizer_name if cfg.model.tokenizer_name else cfg.model.model_name,
        cache_dir=cfg.model.cache_dir,
        use_fast=cfg.model.use_fast,
    )

    model_t5 = T5(cfg, tokenizer)
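    # T5 is this repo's LightningModule wrapper around the fusion-in-decoder
    # model; Trainer.fit()/test() below operate on it directly.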

    if cfg.model.model_checkpoint:
        logger.info(f"Loading checkpoint {cfg.model.model_checkpoint} and continuing training")
        model_checkpoint = torch.load(cfg.model.model_checkpoint, map_location=lambda storage, loc: storage)
        model_dict = model_checkpoint['state_dict']
        model_t5.load_state_dict(model_dict)

    # training and testing
    if cfg.do_train:

        train_dataloader = generate_dataloader(
            data_dir = cfg.data.data_dir,
            tokenizer = tokenizer,
            max_source_length = cfg.data.max_source_length,
            max_target_length = cfg.data.max_target_length,
            overwrite_cache = cfg.data.overwrite_cache,
            mode = "train",
            batch_size = cfg.optim.train_batch_size,
            question_type = cfg.data.question_type,
            passage_type = cfg.data.passage_type,
            enable_sql_supervision = cfg.data.enable_sql_supervision,
            cand_for_each_source = cfg.data.cand_for_each_source,
        )

        dev_dataloader = generate_dataloader(
            data_dir = cfg.data.data_dir,
            tokenizer = tokenizer,
            max_source_length = cfg.data.max_source_length,
            max_target_length = cfg.data.max_target_length,
            overwrite_cache = cfg.data.overwrite_cache,
            mode = "dev",
            batch_size = cfg.optim.dev_batch_size,
            question_type = cfg.data.question_type,
            passage_type = cfg.data.passage_type,
            enable_sql_supervision = cfg.data.enable_sql_supervision,
            cand_for_each_source = cfg.data.cand_for_each_source,
        )

        logger.info("Training starts")
        # tb_logger = loggers.WandbLogger(save_dir=cfg.optim.logging_dir, project='fusion in decoder')
        trainer = pl.Trainer(
            # logger=tb_logger,
            checkpoint_callback=checkpoint_callback,
            **OmegaConf.to_container(cfg.trainer, resolve=True),
        )
        trainer.fit(model_t5, train_dataloader, dev_dataloader)
        # trainer.test(model_t5)

    if cfg.do_eval:

        test_dataloader = generate_dataloader(
            data_dir = cfg.data.data_dir,
            tokenizer = tokenizer,
            max_source_length = cfg.data.max_source_length,
            max_target_length = cfg.data.max_target_length,
            overwrite_cache = cfg.data.overwrite_cache,
            mode = "test",
            batch_size = cfg.optim.test_batch_size,
            question_type = cfg.data.question_type,
            passage_type = cfg.data.passage_type,
            enable_sql_supervision = cfg.data.enable_sql_supervision,
            cand_for_each_source = cfg.data.cand_for_each_source,
        )

        logger.info("Evaluation starts")
        best_checkpoint_file = None
        if cfg.model.model_checkpoint is None:
            # Find the best checkpoint by parsing the val_loss encoded in each
            # checkpoint filename (e.g. 'epoch=3-val_loss=0.1234.ckpt').
            best_val_loss = 10000.
            for checkpoint_file in glob.glob(os.path.join(checkpoint_dir, "*val_loss*.ckpt")):
                try:
                    val_loss = float(checkpoint_file.split('=')[-1].replace(".ckpt", ""))
                except ValueError:
                    continue
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_checkpoint_file = checkpoint_file
            logger.info(f"Loading the checkpoint: {best_checkpoint_file}")
        else:
            best_checkpoint_file = cfg.model.model_checkpoint

        # load model
        if best_checkpoint_file is not None:
            best_checkpoint = torch.load(best_checkpoint_file, map_location=lambda storage, loc: storage)
            model_t5.load_state_dict(best_checkpoint['state_dict'])

        # test using Trainer test function
        trainer = pl.Trainer(**OmegaConf.to_container(cfg.trainer, resolve=True))
        trainer.test(model_t5, test_dataloader)
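
The keys read above imply the shape of fusion_config{suffix}.yaml. Below is a minimal sketch of such a config, reconstructed only from those accesses; every value is an illustrative assumption, not the repo's actual setting.

from omegaconf import OmegaConf

# Hypothetical sketch: keys mirror the cfg.* accesses in main(); all values
# are placeholder assumptions.
sketch = OmegaConf.create({
    "do_train": True,
    "do_eval": True,
    "data": {
        "data_dir": "",            # replaced with SM_CHANNEL_TRAIN at runtime
        "output_dir": "",          # replaced with SM_MODEL_DIR/output
        "max_source_length": 512,
        "max_target_length": 64,
        "overwrite_cache": False,
        "question_type": "natural",    # placeholder
        "passage_type": "passage",     # placeholder
        "enable_sql_supervision": False,
        "cand_for_each_source": 10,
    },
    "model": {
        "model_name": "t5-base",
        "tokenizer_name": None,
        "cache_dir": None,
        "use_fast": False,
        "checkpoint_dir": "",      # replaced with SM_MODEL_DIR/ckpt
        "model_checkpoint": None,
    },
    "optim": {
        "seed": 42,
        "train_batch_size": 2,
        "dev_batch_size": 2,
        "test_batch_size": 2,
        "logging_dir": "logs",
    },
    "trainer": {                   # forwarded verbatim to pl.Trainer(**...)
        "max_epochs": 10,
        "gradient_clip_val": 1.0,
    },
})
OmegaConf.save(sketch, "fusion_config.yaml")

For a rough dry run outside SageMaker, the two environment variables the script reads can be pointed at local directories before calling main(); note that the config path itself stays hard-coded under /opt/ml/code. This is an assumption about local testing, not something the repo provides.

import os

os.environ.setdefault("SM_CHANNEL_TRAIN", "/tmp/fid_data")   # local stand-in for the train channel
os.environ.setdefault("SM_MODEL_DIR", "/tmp/fid_model")      # local stand-in for the model dir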