def train()

in sagemaker/src/hf_train_deploy.py [0:0]


def train(args):
    """Fine-tune a sequence-classification model, evaluate it, and save it.

    Loads ``train.csv`` and ``valid.csv`` from ``args.training_dir`` via
    ``_get_dataset``, fine-tunes ``args.model_name`` with a Hugging Face
    ``Trainer`` (evaluating every 50 steps), runs a final evaluation on the
    validation set, writes the sorted metrics to
    ``<args.output_data_dir>/eval_results.txt`` (also echoing them to stdout)
    and saves the model to ``args.model_dir``.

    Args:
        args: Parsed argument namespace. Attributes read here: ``seed``,
            ``training_dir``, ``text_column``, ``label_column``,
            ``model_name``, ``output_data_dir``, ``epochs``,
            ``train_batch_size``, ``eval_batch_size``, ``warmup_steps``,
            ``learning_rate`` (converted with ``float``) and ``model_dir``.
    """
    set_seed(args.seed)

    train_dataset = _get_dataset(args.training_dir, "train.csv", args.text_column, args.label_column)
    valid_dataset = _get_dataset(args.training_dir, "valid.csv", args.text_column, args.label_column)

    def compute_metrics(pred):
        """Return accuracy, F1, precision and recall for a binary task.

        ``pred`` is presumably a transformers ``EvalPrediction`` exposing
        ``label_ids`` and ``predictions`` (logits). ``average="binary"``
        scores the positive label only and expects exactly two label values
        — confirm the dataset's labels are {0, 1}.
        """
        labels = pred.label_ids
        # Logits -> predicted class index along the last axis.
        preds = pred.predictions.argmax(-1)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary")
        acc = accuracy_score(labels, preds)
        return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

    # Download the pretrained weights from the model hub.
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name)

    training_args = TrainingArguments(
        output_dir=args.output_data_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        warmup_steps=args.warmup_steps,
        seed=args.seed,
        save_steps=500,
        save_total_limit=2,
        evaluation_strategy="steps",
        eval_steps=50,
        logging_steps=50,
        logging_dir=args.output_data_dir,
        learning_rate=float(args.learning_rate),
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
    )

    trainer.train()

    eval_result = trainer.evaluate(eval_dataset=valid_dataset)

    # Write the eval results so they are available later alongside the job
    # output. The directory may not exist yet if training finished before
    # any checkpoint was saved (save_steps=500), so create it explicitly.
    os.makedirs(args.output_data_dir, exist_ok=True)
    eval_path = os.path.join(args.output_data_dir, "eval_results.txt")
    print("***** Eval results *****")
    with open(eval_path, "w", encoding="utf-8") as writer:
        for key, value in sorted(eval_result.items()):
            line = f"{key} = {value}"
            print(line)
            writer.write(f"{line}\n")

    # Persist the final model; model_dir is presumably SageMaker's model
    # directory, which the original comment says ends up in S3.
    trainer.save_model(args.model_dir)