def main()

in anli/run_causal_lm.py [0:0]


def _load_config(model_args):
    """Return the model config for this run.

    Prefers ``model_args.config_name``, then ``model_args.model_name_or_path``;
    if neither is set, instantiates a fresh default config for
    ``model_args.model_type``.
    """
    if model_args.config_name:
        return AutoConfig.from_pretrained(
            model_args.config_name, cache_dir=model_args.cache_dir
        )
    if model_args.model_name_or_path:
        return AutoConfig.from_pretrained(
            model_args.model_name_or_path, cache_dir=model_args.cache_dir
        )
    config = CONFIG_MAPPING[model_args.model_type]()
    logger.warning("You are instantiating a new config instance from scratch.")
    return config


def _load_tokenizer(model_args):
    """Return the tokenizer named by ``tokenizer_name`` (or ``model_name_or_path``).

    Raises:
        ValueError: if neither name is set; training a tokenizer from scratch
            is not supported by this script.
    """
    tokenizer_source = model_args.tokenizer_name or model_args.model_name_or_path
    if not tokenizer_source:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it, "
            "and load it from here, using --tokenizer_name"
        )
    return AutoTokenizer.from_pretrained(
        tokenizer_source, cache_dir=model_args.cache_dir
    )


def _build_frozen_causal_lm(model_args, config, tokenizer):
    """Build a causal LM whose base model is frozen, optionally seeded with NLI weights.

    The LM is created with ``AutoModelForCausalLM.from_config(config)`` (no
    pretrained LM checkpoint is loaded here). When ``model_args.load_nli_model``
    is set, the sequence-classification model named by
    ``model_args.model_class_name`` is restored from
    ``model_args.model_checkpoint_path`` and its base-model weights are copied
    into the LM's base model.

    Token embeddings are resized to ``len(tokenizer)`` *before* freezing:
    ``resize_token_embeddings`` creates a new embedding module when the
    vocabulary size changes, so resizing after the freeze would leave that new
    matrix trainable.

    Returns:
        The model, with every ``model.base_model`` parameter set to
        ``requires_grad = False``.
    """
    logger.info("Initializing Causal LM model ...")
    model = AutoModelForCausalLM.from_config(config)
    if model_args.load_nli_model:
        logger.info(f"Loading NLI model {model_args.model_class_name}")
        model_class_item = MODEL_CLASSES[model_args.model_class_name]
        model_name = model_class_item["model_name"]
        nli_model = model_class_item["sequence_classification"].from_pretrained(
            model_name, cache_dir=str(nli_config.PRO_ROOT / "trans_cache"), num_labels=3
        )
        logger.info(
            f"Loading NLI pre-trained weights from {model_args.model_checkpoint_path}"
        )
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
        nli_model.load_state_dict(
            torch.load(model_args.model_checkpoint_path, map_location="cpu")
        )
        model.base_model.load_state_dict(nli_model.base_model.state_dict())
    else:
        # NOTE(review): the model above comes from ``from_config``, so no
        # pretrained LM weights are actually loaded on this path — confirm the
        # log message / intended behavior.
        logger.info("Not loading any NLI models, using the previous pre-trained models")

    # Resize first so that any newly created embedding matrix is frozen below.
    model.resize_token_embeddings(len(tokenizer))

    # Freeze encoder params; only parameters outside ``base_model`` stay trainable.
    logger.info("Freezing params ...")
    for param in model.base_model.parameters():
        param.requires_grad = False
    learnable_names = [
        name for name, param in model.named_parameters() if param.requires_grad
    ]
    logger.info(f"Total learnable params = {len(learnable_names)}")
    logger.info(",".join(learnable_names))
    return model


def _load_decoder_weights(model, output_dir, checkpoint_file):
    """Load saved weights into ``model``.

    Reads ``<output_dir>/pytorch_model.bin``, or
    ``<output_dir>/<checkpoint_file>/pytorch_model.bin`` when
    ``checkpoint_file`` is non-empty. ``None`` is treated like ``""``.
    """
    weights_dir = Path(output_dir)
    if checkpoint_file:
        weights_dir = weights_dir / checkpoint_file
    # Map to CPU first; load_state_dict copies into the model's existing tensors.
    model.load_state_dict(
        torch.load(weights_dir / "pytorch_model.bin", map_location="cpu")
    )


def _evaluate_perplexity(trainer, model, eval_dataset):
    """Return the list of per-batch perplexities ``exp(loss)`` over ``eval_dataset``.

    ``trainer.prediction_step`` only runs the forward pass. Switching to eval
    mode is normally done by the Trainer's own evaluation loop, which this
    code bypasses, so ``model.eval()`` is called here to disable dropout.
    """
    logger.info(f"Eval batch size : {trainer.args.eval_batch_size}")
    eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
    model.eval()
    batch_perplexities = []
    for inputs in tqdm(eval_dataloader, desc="Eval"):
        loss, _, _ = trainer.prediction_step(model, inputs, prediction_loss_only=True)
        batch_perplexities.append(math.exp(loss))
    return batch_perplexities


def main(model_args, data_args, training_args):
    """Train and/or evaluate a causal-LM decoder on top of a frozen base model.

    Args:
        model_args: model/tokenizer options (config/tokenizer names, cache dir,
            NLI-loading flags, ``checkpoint_file``, ``output_file``).
        data_args: dataset options (``eval_data_file``, ``block_size``,
            MLM/PLM collator settings); ``block_size`` is mutated in place.
        training_args: the ``Trainer`` arguments (passed as ``args=``).

    Returns:
        dict: empty unless ``training_args.do_eval``; otherwise
        ``{"perplexity": [per-batch exp(loss)], "mean_perplexity": mean of those}``.
        Note the mean is over per-batch perplexities, not ``exp(mean loss)``.

    Raises:
        ValueError: if evaluation is requested without an eval file, if the
            output directory is non-empty for training without
            ``overwrite_output_dir``, or if no tokenizer source is given.
    """
    if data_args.eval_data_file is None and training_args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument."
        )

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging: INFO on the main process, WARN on the others.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    set_seed(training_args.seed)

    # The .from_pretrained methods guarantee that only one local process can
    # concurrently download model & vocab in distributed training.
    config = _load_config(model_args)
    tokenizer = _load_tokenizer(model_args)
    model = _build_frozen_causal_lm(model_args, config, tokenizer)

    # Our input block size will be the max possible for the model when
    # block_size <= 0; otherwise it is capped at that maximum.
    if data_args.block_size <= 0:
        data_args.block_size = tokenizer.max_len
    else:
        data_args.block_size = min(data_args.block_size, tokenizer.max_len)

    # Get datasets
    train_dataset = (
        get_dataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir)
        if training_args.do_train
        else None
    )
    if train_dataset is not None:
        logger.info(f"Read training dataset : {len(train_dataset)}")
    eval_dataset = (
        get_dataset(
            data_args,
            tokenizer=tokenizer,
            evaluate=True,
            cache_dir=model_args.cache_dir,
        )
        if training_args.do_eval
        else None
    )
    if eval_dataset is not None:
        logger.info(f"Read eval dataset : {len(eval_dataset)}")

    if config.model_type == "xlnet":
        data_collator = DataCollatorForPermutationLanguageModeling(
            tokenizer=tokenizer,
            plm_probability=data_args.plm_probability,
            max_span_length=data_args.max_span_length,
        )
    else:
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=data_args.mlm,
            mlm_probability=data_args.mlm_probability,
        )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        prediction_loss_only=True,
    )

    # Training
    if training_args.do_train:
        # Resume from model_name_or_path only when it is a local directory.
        model_path = (
            model_args.model_name_or_path
            if model_args.model_name_or_path is not None
            and os.path.isdir(model_args.model_name_or_path)
            else None
        )
        trainer.train(model_path=model_path)
        trainer.save_model()
        # Re-save the tokenizer next to the model so the output dir is self-contained.
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("Loading decoder...")
        _load_decoder_weights(
            model, training_args.output_dir, model_args.checkpoint_file
        )
        logger.info("*** Evaluate ***")
        eval_perplexity = _evaluate_perplexity(trainer, model, eval_dataset)
        result = {
            "perplexity": eval_perplexity,
            "mean_perplexity": np.mean(eval_perplexity),
        }

        if trainer.is_world_master():
            with open(model_args.output_file, "w") as eval_file:
                json.dump(result, eval_file)
            logger.info("***** Eval results *****")
            logger.info(f"Perplexity : {result['mean_perplexity']}")

        results.update(result)

    return results