def main()

in relogic/logical-tabart-pretraining.py [0:0]


def _join_channel_paths(channel_dir, file_list):
    """Prefix every entry of a comma-separated file list with ``channel_dir``.

    Args:
        channel_dir: Directory to prepend (here, the ``SM_CHANNEL_TRAIN`` directory).
        file_list: Comma-separated relative file names, or None.

    Returns:
        The comma-joined prefixed paths, or None when ``file_list`` is None.
        Passing None through matters: optional inputs such as the eval file
        (absent when --do_eval is off) would otherwise crash on ``None.split``.
    """
    if file_list is None:
        return None
    return ",".join(os.path.join(channel_dir, item) for item in file_list.split(","))


def _apply_sagemaker_overrides(model_args, data_args, training_args):
    """Rewrite the parsed arguments in place for a SageMaker run.

    In SageMaker mode the boolean flags arrive as strings in the ``*_str``
    fields. Data files are resolved relative to ``$SM_CHANNEL_TRAIN``. The
    output directory is forced to ``$SM_MODEL_DIR``. The pretrained-checkpoint
    directory comes from ``$SM_CHANNEL_PRETRAINED_CKPT_DIR`` (or None).

    Raises:
        KeyError: if ``SM_CHANNEL_TRAIN`` or ``SM_MODEL_DIR`` is not set.
    """
    training_args.do_train = training_args.do_train_str == "True"
    training_args.do_eval = training_args.do_eval_str == "True"
    training_args.evaluate_during_training = training_args.evaluate_during_training_str == "True"
    train_channel = os.environ["SM_CHANNEL_TRAIN"]
    data_args.train_data_file = _join_channel_paths(train_channel, data_args.train_data_file)
    data_args.eval_data_file = _join_channel_paths(train_channel, data_args.eval_data_file)
    training_args.output_dir = os.environ["SM_MODEL_DIR"]
    model_args.pretrained_ckpt_dir = os.environ.get("SM_CHANNEL_PRETRAINED_CKPT_DIR", None)


def _load_tokenizer(model_args):
    """Load the slow (``use_fast=False``) tokenizer and register the ``<col>`` special token.

    ``--tokenizer_name`` takes precedence over ``--model_name_or_path``.

    Raises:
        ValueError: if neither a tokenizer name nor a model path was supplied.
    """
    source = model_args.tokenizer_name or model_args.model_name_or_path
    if not source:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it, "
            "and load it from here, using --tokenizer_name"
        )
    tokenizer = AutoTokenizer.from_pretrained(source, cache_dir=model_args.cache_dir, use_fast=False)
    tokenizer.add_special_tokens({"additional_special_tokens": ["<col>"]})
    return tokenizer


def _build_model(model_args, data_args, training_args, tokenizer):
    """Create ``LogicalTaBARTModel``, resize its embeddings to the tokenizer and tie them.

    For an evaluation-only run (--do_eval without --do_train), the weights are
    loaded from ``<model_name_or_path>/pytorch_model.bin``.

    Raises:
        ValueError: on an evaluation-only run without --model_name_or_path.
    """
    model = LogicalTaBARTModel(data_args.task_names)
    # Both sub-models must accept the ids of the added "<col>" token.
    model.bert.resize_token_embeddings(len(tokenizer))
    model.bert_for_texttosql.resize_token_embeddings(len(tokenizer))
    # Make `bert` reuse bert_for_texttosql's embedding Parameters, so the two
    # sub-models share (and jointly train) a single embedding matrix.
    model.bert.model.shared.weight = model.bert_for_texttosql.model.shared.weight
    model.bert.model.encoder.embed_tokens.weight = model.bert_for_texttosql.model.encoder.embed_tokens.weight

    if training_args.do_eval and not training_args.do_train:
        if model_args.model_name_or_path is None:
            raise ValueError(
                "Evaluation without training requires --model_name_or_path pointing to a directory "
                "that contains pytorch_model.bin."
            )
        checkpoint_file = os.path.join(model_args.model_name_or_path, "pytorch_model.bin")
        # load_state_dict copies values into the model's own parameters, so
        # loading onto the CPU first is safe and also works for checkpoints
        # saved from a GPU when no GPU is available.
        model_param = torch.load(checkpoint_file, map_location="cpu")
        model.load_state_dict(model_param)  # strict by default: missing/unexpected keys raise
        logger.info("All key matched and load successfully.")
    return model


def _resolve_label_eos_id(data_collators):
    """Return the ``label_eos_id`` shared by all collators (None if there are none).

    Raises:
        ValueError: if two collators report different EOS ids. This is a real
            exception rather than ``assert`` so the check survives ``python -O``.
    """
    eos_id = None
    for collator in data_collators:
        if eos_id is None:
            eos_id = collator.label_eos_id
        elif eos_id != collator.label_eos_id:
            raise ValueError(
                f"All data collators must share the same label_eos_id, got {eos_id} and {collator.label_eos_id}."
            )
    return eos_id


def _perplexity_from_loss(eval_loss):
    """Return ``exp(eval_loss)``, or ``inf`` when the result overflows a float.

    ``math.exp`` raises OverflowError for arguments above roughly 709. That
    would otherwise crash a run whose evaluation loss has diverged.
    """
    try:
        return math.exp(eval_loss)
    except OverflowError:
        return float("inf")


def main():
    """Parse arguments, build tokenizer/model/datasets, then train and/or evaluate.

    Returns:
        dict: ``{"perplexity": ...}`` when --do_eval is set, otherwise ``{}``.

    Raises:
        ValueError: for inconsistent arguments (see the checks below and in the helpers).
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if is_sagemaker:
        _apply_sagemaker_overrides(model_args, data_args, training_args)

    if model_args.pretrained_ckpt_dir is not None and model_args.load_from_pretrained_ckpt is not None:
        model_args.load_from_pretrained_ckpt = os.path.join(
            model_args.pretrained_ckpt_dir, model_args.load_from_pretrained_ckpt
        )

    if data_args.eval_data_file is None and training_args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument."
        )

    # Guard against clobbering a previous run; skipped on SageMaker, where
    # output_dir has been forced to SM_MODEL_DIR.
    if (
        not is_sagemaker
        and os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging: only the main process (local_rank -1 or 0) logs at INFO.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Initialize models and tokenizer
    tokenizer = _load_tokenizer(model_args)
    model = _build_model(model_args, data_args, training_args, tokenizer)

    # Our input block size will be the max possible for the model
    if data_args.block_size <= 0:
        data_args.block_size = tokenizer.model_max_length
    else:
        data_args.block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Get datasets
    train_datasets = get_datasets(model_args.pretraining_model, data_args, tokenizer=tokenizer) if training_args.do_train else None
    eval_datasets = get_datasets(model_args.pretraining_model, data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
    data_collators = get_data_collators(model_args.pretraining_model, data_args, tokenizer=tokenizer)

    match_sequence_scorer = MatchSequenceScorer(
        eos_id=_resolve_label_eos_id(data_collators),
        output_path=os.path.join(training_args.output_dir, "eval_dump.json"),
    )
    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collators=data_collators,
        train_datasets=train_datasets,
        eval_datasets=eval_datasets,
        compute_metrics=match_sequence_scorer
    )

    # Training
    if training_args.do_train:
        # Resume from model_name_or_path only when it is a local directory.
        model_path = (
            model_args.model_name_or_path
            if model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)
            else None
        )
        trainer.train(model_path=model_path)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()
        # NOTE(review): only perplexity is reported; any metrics produced by
        # match_sequence_scorer inside eval_output are not copied into `results`.
        result = {"perplexity": _perplexity_from_loss(eval_output["eval_loss"])}

        output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt")
        if trainer.is_world_master():
            with open(output_eval_file, "w", encoding="utf-8") as writer:
                logger.info("***** Eval results *****")
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))

        results.update(result)

    return results