sagemaker/14_train_and_push_to_hub/scripts/train.py

import argparse
import logging
import os
import random
import sys

import numpy as np
import torch
from datasets import load_from_disk, load_metric
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from transformers.trainer_utils import get_last_checkpoint

if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--eval_batch_size", type=int, default=64)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--model_id", type=str)
    parser.add_argument("--learning_rate", type=str, default=5e-5)
    parser.add_argument("--fp16", type=bool, default=True)

    # Push to Hub parameters
    parser.add_argument("--push_to_hub", type=bool, default=True)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_strategy", type=str, default=None)
    parser.add_argument("--hub_token", type=str, default=None)

    # Data, model, and output directories
    parser.add_argument("--output_data_dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"])
    parser.add_argument("--output_dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"])
    parser.add_argument("--training_dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"])
    parser.add_argument("--test_dir", type=str, default=os.environ["SM_CHANNEL_TEST"])

    args, _ = parser.parse_known_args()

    # make sure we have the required parameters to push to the Hub
    if args.push_to_hub:
        if args.hub_strategy is None:
            raise ValueError("--hub_strategy is required when pushing to Hub")
        if args.hub_token is None:
            raise ValueError("--hub_token is required when pushing to Hub")

    # set the hub model id if not provided
    if args.hub_model_id is None:
        args.hub_model_id = args.model_id.replace("/", "--")

    # set up logging
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # load datasets
    train_dataset = load_from_disk(args.training_dir)
    test_dataset = load_from_disk(args.test_dir)

    logger.info(f" loaded train_dataset length is: {len(train_dataset)}")
    logger.info(f" loaded test_dataset length is: {len(test_dataset)}")

    # define metric and metrics function
    metric = load_metric("accuracy")

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return metric.compute(predictions=predictions, references=labels)

    # prepare model labels - useful for the inference API
    labels = train_dataset.features["labels"].names
    num_labels = len(labels)
    label2id, id2label = dict(), dict()
    for i, label in enumerate(labels):
        label2id[label] = str(i)
        id2label[str(i)] = label

    # download model from the model hub
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_id, num_labels=num_labels, label2id=label2id, id2label=id2label
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)

    # define training args
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True if get_last_checkpoint(args.output_dir) is not None else False,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        warmup_steps=args.warmup_steps,
        fp16=args.fp16,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        logging_dir=f"{args.output_data_dir}/logs",
        learning_rate=float(args.learning_rate),
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        # push-to-Hub parameters
        push_to_hub=args.push_to_hub,
        hub_strategy=args.hub_strategy,
        hub_model_id=args.hub_model_id,
        hub_token=args.hub_token,
    )

    # create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
    )

    # train model
    trainer.train()

    # evaluate model
    eval_result = trainer.evaluate(eval_dataset=test_dataset)

    # save the best model, metrics, and create a model card
    trainer.create_model_card(model_name=args.hub_model_id)
    trainer.push_to_hub()

    # saves the model to S3; using os.environ["SM_MODEL_DIR"] makes sure checkpointing works
    trainer.save_model(os.environ["SM_MODEL_DIR"])
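
For context: the argparse flags above are filled in by SageMaker, which serializes the estimator's hyperparameters dict into command-line arguments, while the SM_CHANNEL_* and SM_MODEL_DIR defaults are populated by the training container. The following is a minimal launcher sketch (not part of train.py); the role ARN, S3 URIs, framework versions, Hub repo name, and token are placeholders, and exact supported version combinations depend on the SageMaker Hugging Face DLC you use.

    # hypothetical launcher, run outside the training container
    from sagemaker.huggingface import HuggingFace

    hyperparameters = {
        "model_id": "distilbert-base-uncased",          # any Hub checkpoint (assumption)
        "epochs": 3,
        "train_batch_size": 32,
        "learning_rate": 5e-5,
        # push-to-Hub parameters consumed by train.py
        "push_to_hub": True,
        "hub_model_id": "my-user/my-finetuned-model",   # placeholder repo name
        "hub_strategy": "every_save",
        "hub_token": "hf_xxx",                          # placeholder token
    }

    huggingface_estimator = HuggingFace(
        entry_point="train.py",
        source_dir="./scripts",
        instance_type="ml.p3.2xlarge",                  # example GPU instance
        instance_count=1,
        role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder role
        transformers_version="4.12",                    # example version combo
        pytorch_version="1.9",
        py_version="py38",
        hyperparameters=hyperparameters,
    )

    # each channel becomes SM_CHANNEL_TRAIN / SM_CHANNEL_TEST inside the container
    huggingface_estimator.fit(
        {"train": "s3://my-bucket/train", "test": "s3://my-bucket/test"}  # placeholder S3 URIs
    )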