def train()

in notebooks/src/code/train.py [0:0]


def _early_stopping_callbacks(training_args):
    """Build the Trainer callback list for early stopping.

    Returns an empty list when neither `early_stopping_patience` nor
    `early_stopping_threshold` is set on `training_args`. Otherwise returns a single
    `EarlyStoppingCallback`, forwarding only the values that are not None so that an unset option
    keeps the callback's own default rather than being overridden with an explicit None.
    """
    patience = training_args.early_stopping_patience
    threshold = training_args.early_stopping_threshold
    if patience is None and threshold is None:
        return []

    callback_kwargs = {}
    if patience is not None:
        callback_kwargs["early_stopping_patience"] = patience
    if threshold is not None:
        callback_kwargs["early_stopping_threshold"] = threshold
    return [EarlyStoppingCallback(**callback_kwargs)]


def _copy_code_for_inference(code_path):
    """Copy the current working directory's files into `code_path`, preserving relative layout.

    Files and folders whose name starts with a dot, and any whose name contains "__pycache__",
    are skipped. The destination folder itself is excluded from the walk so that the copy cannot
    recurse into its own output when `code_path` lies under the current working directory.
    """
    source_root = "."
    dest_abs = os.path.abspath(code_path)
    for currpath, dirnames, files in os.walk(source_root):
        # Prune in place so os.walk never descends into skipped folders:
        dirnames[:] = [
            dirname
            for dirname in dirnames
            if not dirname.startswith(".")
            and "__pycache__" not in dirname
            and os.path.abspath(os.path.join(currpath, dirname)) != dest_abs
        ]
        for file in files:
            # Skip any filenames starting with dot, or pycache artifacts:
            if file.startswith(".") or "__pycache__" in file:
                continue
            filepath = os.path.join(currpath, file)
            outpath = os.path.join(code_path, os.path.relpath(filepath, source_root))
            logger.info(f"Copying {filepath} to {outpath}")
            os.makedirs(os.path.dirname(outpath), exist_ok=True)
            shutil.copy2(filepath, outpath)


def train(model_args, data_args, training_args):
    """Load model and datasets, optionally train, then save the model and inference code.

    Steps performed:
      1. Load the model and tokenizer via `get_model` and the datasets via `data.get_datasets`.
      2. If `training_args.do_train` is set and `training_args.output_dir` exists, look for the
         last checkpoint there so training can resume from it.
      3. Build a `Trainer` (with an early stopping callback when either early stopping option is
         set on `training_args`).
      4. If `training_args.do_train` is set, train, save the model, and log/save train metrics
         and trainer state.
      5. Save the model to `training_args.model_dir` and copy the current working directory's
         code into `<model_dir>/code` for inference.

    Args:
        model_args: Passed through to `get_model`.
        data_args: Passed to `get_model` and `data.get_datasets`; `validation` and
            `max_train_samples` are also read here.
        training_args: Trainer arguments. This function reads `output_dir`, `do_train`,
            `resume_from_checkpoint`, `early_stopping_patience`, `early_stopping_threshold` and
            `model_dir`.

    Returns:
        The `Trainer` instance (trained only if `training_args.do_train` was set).

    Raises:
        ValueError: If the loaded tokenizer is not a `PreTrainedTokenizerFast`.
    """
    logger.info("Creating config and model")
    _, model, tokenizer = get_model(model_args, data_args)

    # Tokenizer check: this script requires a fast tokenizer.
    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        raise ValueError(
            "This example script only works for models that have a fast tokenizer. See the list "
            "at https://huggingface.co/transformers/index.html#supported-frameworks for details."
        )

    logger.info("Loading datasets")
    datasets = data.get_datasets(data_args, tokenizer)
    if datasets.train_dataset:
        logger.info(f"train dataset has {len(datasets.train_dataset)} samples")
    else:
        logger.info("No training dataset provided")
    if datasets.eval_dataset:
        logger.info(f"validation dataset has {len(datasets.eval_dataset)} samples")
    else:
        logger.info("No validation dataset provided")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            logger.warning("No previous checkpoint found: training from scratch")
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this "
                "behavior, create the training job with an empty `checkpoint_s3_uri` or none."
            )

    logger.info("Setting up trainer")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=datasets.train_dataset,
        eval_dataset=datasets.eval_dataset if data_args.validation else None,
        tokenizer=tokenizer,
        data_collator=datasets.data_collator,
        callbacks=_early_stopping_callbacks(training_args),
        compute_metrics=datasets.metric_computer,
    )

    if not training_args.do_train:
        logger.warning(f"Training skipped (args.do_train={training_args.do_train})")
    else:
        # An explicitly requested checkpoint takes priority over an auto-detected one:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        trainer.save_model()  # (Saves the tokenizer too)

        n_train = len(datasets.train_dataset)
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else n_train
        )
        metrics["train_samples"] = min(max_train_samples, n_train)

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    logger.info(f"Saving model to {training_args.model_dir}")
    trainer.save_model(training_args.model_dir)  # (Saves the tokenizer too)

    # To enable directly deploying this model via SageMaker SDK's Estimator.deploy() (rather than
    # needing to create a PyTorchModel with entry_point / source_dir args), we need to save any
    # inference handler function code to model_dir/code. Here we compromise efficiency to the
    # benefit of usage simplicity, by just copying the contents of this training code folder to the
    # model/code folder for inference:
    code_path = os.path.join(training_args.model_dir, "code")
    logger.info(f"Copying code to {code_path} for inference")
    _copy_code_for_inference(code_path)

    return trainer