def main()

in modelling/src/neuraldb/run.py [0:0]


def main():
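    """Train, evaluate, and test a seq2seq NeuralDB model with the HuggingFace Trainer."""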
    setup_logging()

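    # Parse command-line arguments into the model, data, and training dataclasses.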
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

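    # Look for an existing checkpoint in the output directory so training can resume from it.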
    last_checkpoint = None
    if (
        os.path.isdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif (
            last_checkpoint is not None and training_args.resume_from_checkpoint is None
        ):
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Log a brief summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the Transformers logger verbosity to info (on the main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()

    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

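    # Collect the dataset files supplied for each split.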
    data_files = {}
    if data_args.train_file is not None:
        data_files["train"] = data_args.train_file
    if data_args.validation_file is not None:
        data_files["validation"] = data_args.validation_file
    if data_args.test_file is not None:
        data_files["test"] = data_args.test_file

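    # Register max_source_length as the maximum input size for this checkpoint on the fast T5 tokenizer.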
    transformers.models.t5.tokenization_t5_fast.T5TokenizerFast.max_model_input_sizes[
        model_args.model_name_or_path
    ] = data_args.max_source_length

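    # For T5 checkpoints, also raise n_positions in the config to the maximum source length.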
    config_kwargs = {}
    if "t5" in model_args.model_name_or_path:
        config_kwargs.update({"n_positions": data_args.max_source_length})

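    # Load the model configuration, with max_target_length as the default generation length.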
    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name
        else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        max_length=data_args.max_target_length,
        **config_kwargs,
    )

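    # Load the slow tokenizer for the model (or an explicitly named tokenizer).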
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name
        else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=False,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length
    padding = "max_length" if data_args.pad_to_max_length else False

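    # Resolve the dataset reader, instance generator, and evaluation metrics for the chosen instance generator.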
    reader_cls, generator_cls, generator_kwargs, evaluation_metrics = get_generator(
        data_args.instance_generator
    )

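    # Build an instance generator and dataset reader per split; non-train splits run in test mode.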
    generators = {}
    datasets = {}
    for split, path in data_files.items():
        generator = generator_cls(
            tokenizer,
            maximum_source_length=data_args.max_source_length,
            maximum_target_length=max_target_length,
            padding=padding,
            ignore_pad_token_for_loss=data_args.ignore_pad_token_for_loss,
            test_mode=(split != "train"),
            **generator_kwargs,
        )
        dataset_reader = reader_cls(instance_generator=generator)

        generators[split] = generator
        datasets[split] = Seq2SeqDataset(
            dataset_reader.read(path),
            auto_pad=generator.encode,
        )

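    # Metrics are computed against the validation generator, falling back to the test generator.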
    compute_metrics = evaluation_metrics(
        data_args,
        tokenizer,
        generators["validation"] if "validation" in generators else generators["test"],
    )
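
    # Load the pretrained sequence-to-sequence model.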
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    if training_args.label_smoothing_factor > 0 and not hasattr(
        model, "prepare_decoder_input_ids_from_labels"
    ):
        logger.warning(
            "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
            f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
        )

    # Data collator
    label_pad_token_id = (
        -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    )
    data_collator = DataCollatorForSeq2SeqAllowMetadata(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=8
        if training_args.fp16
        else (1024 if "led" in model_args.model_name_or_path else None),
    )
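
    # Keep the model's embedding matrix in sync with the tokenizer vocabulary size.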
    model.resize_token_embeddings(len(tokenizer))

    if training_args.do_train or training_args.do_eval:
        # Initialize our Trainer
        trainer = NeuralDBTrainer(
            model=model,
            args=training_args,
            train_dataset=datasets["train"] if training_args.do_train else None,
            eval_dataset=datasets["validation"] if training_args.do_eval else None,
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
        )

    # Training
    if training_args.do_train:
        if last_checkpoint is not None:
            checkpoint = last_checkpoint
        elif os.path.isdir(model_args.model_name_or_path):
            checkpoint = model_args.model_name_or_path
        else:
            checkpoint = None
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples
            if data_args.max_train_samples is not None
            else len(datasets["train"])
        )
        metrics["train_samples"] = min(max_train_samples, len(datasets["train"]))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        trainer.log_metrics("eval", flatten_dicts(metrics))
        trainer.save_metrics("eval", metrics)

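    # Prediction on the held-out test split, with metrics rebuilt from the test generator.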
    if training_args.do_predict:
        logger.info("*** Test ***")

        compute_metrics = evaluation_metrics(data_args, tokenizer, generators["test"])
        tester = NeuralDBTrainer(
            model=model,
            args=training_args,
            eval_dataset=datasets["test"],
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
        )

        metrics = tester.evaluate()
        tester.log_metrics("test", flatten_dicts(metrics))
        tester.save_metrics("test", metrics)