# sagemaker/09_image_classification_vision_transformer/scripts/train.py
import argparse
import logging
import os
import subprocess
import sys

import numpy as np
from datasets import load_from_disk, load_metric
from transformers import (
    Trainer,
    TrainingArguments,
    ViTForImageClassification,
    default_data_collator,
)
# configure a git identity so that trainer.push_to_hub can create commits from the container
subprocess.run(["git", "config", "--global", "user.email", "sagemaker@huggingface.co"], check=True)
subprocess.run(["git", "config", "--global", "user.name", "sagemaker"], check=True)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script
    parser.add_argument("--model_name", type=str)
    parser.add_argument("--output_dir", type=str, default="/opt/ml/model")
    parser.add_argument("--extra_model_name", type=str, default="sagemaker")
    parser.add_argument("--dataset", type=str, default="cifar10")
    parser.add_argument("--task", type=str, default="image-classification")
    parser.add_argument("--use_auth_token", type=str, default="")
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--per_device_train_batch_size", type=int, default=32)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=64)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--training_dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"])
    parser.add_argument("--test_dir", type=str, default=os.environ["SM_CHANNEL_TEST"])

    args, _ = parser.parse_known_args()
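
    # For reference, these arguments are filled in by the SageMaker HuggingFace estimator:
    # hyperparameters become CLI flags, and each fit() channel sets an SM_CHANNEL_* env var.
    # A minimal launch sketch (estimator config, versions, and S3 paths are illustrative):
    #
    #   from sagemaker.huggingface import HuggingFace
    #   huggingface_estimator = HuggingFace(
    #       entry_point="train.py",
    #       source_dir="./scripts",
    #       instance_type="ml.p3.2xlarge",
    #       instance_count=1,
    #       role=role,
    #       transformers_version="4.6",
    #       pytorch_version="1.7",
    #       py_version="py36",
    #       hyperparameters={"model_name": "google/vit-base-patch16-224-in21k", "num_train_epochs": 3},
    #   )
    #   huggingface_estimator.fit({"train": training_input_path, "test": test_input_path})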
    # set up logging
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.INFO,
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    # load the preprocessed datasets from the SageMaker input channels
    train_dataset = load_from_disk(args.training_dir)
    test_dataset = load_from_disk(args.test_dir)
    num_classes = train_dataset.features["label"].num_classes

    logger.info(f" loaded train_dataset length is: {len(train_dataset)}")
    logger.info(f" loaded test_dataset length is: {len(test_dataset)}")
metric_name = "accuracy"
# compute metrics function for binary classification
metric = load_metric(metric_name)
def compute_metrics(eval_pred):
predictions, labels = eval_pred
predictions = np.argmax(predictions, axis=1)
return metric.compute(predictions=predictions, references=labels)
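
    # Above, eval_pred.predictions holds the raw logits with shape (num_examples, num_labels);
    # the argmax over axis 1 converts them to predicted class ids before computing accuracy.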
    # download the pretrained model from the model hub with a freshly sized classification head
    model = ViTForImageClassification.from_pretrained(args.model_name, num_labels=num_classes)

    # replace the generic LABEL_0 ... LABEL_N id/label mappings with the dataset's class names
    label_names = train_dataset.features["label"].names
    model.config.id2label = {i: name for i, name in enumerate(label_names)}
    model.config.label2id = {name: i for i, name in enumerate(label_names)}
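
    # For the default cifar10 dataset this yields, e.g.:
    #   id2label = {0: "airplane", 1: "automobile", 2: "bird", ..., 9: "truck"}
    # so downstream inference returns human-readable class names instead of LABEL_0.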
    # define training args
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        warmup_steps=args.warmup_steps,
        weight_decay=args.weight_decay,
        evaluation_strategy="epoch",
        save_strategy="epoch",  # must match evaluation_strategy when load_best_model_at_end=True
        logging_dir=f"{args.output_dir}/logs",
        learning_rate=args.learning_rate,
        load_best_model_at_end=True,
        metric_for_best_model=metric_name,
    )
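
    # With load_best_model_at_end=True the Trainer checkpoints once per epoch and, after
    # training, reloads the checkpoint with the highest eval accuracy (metric_for_best_model),
    # so trainer.save_model() below exports the best checkpoint rather than the last one.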
    # create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=default_data_collator,
    )
    # train model
    trainer.train()

    # evaluate model
    eval_result = trainer.evaluate(eval_dataset=test_dataset)
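
    # eval_result is a plain dict; with the setup above it typically contains keys such as
    # eval_loss, eval_accuracy, and eval_runtime (the exact keys depend on the transformers version).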
    # write the eval results to a file that can be accessed later in the S3 output
    with open(os.path.join(args.output_dir, "eval_results.txt"), "w") as writer:
        print("***** Eval results *****")
        for key, value in sorted(eval_result.items()):
            writer.write(f"{key} = {value}\n")
    # save the model; SageMaker uploads everything written to output_dir to S3 after the job
    trainer.save_model(args.output_dir)
if args.use_auth_token != "":
kwargs = {
"finetuned_from": args.model_name.split("/")[1],
"tags": "image-classification",
"dataset": args.dataset,
}
repo_name = (
f"{args.model_name.split('/')[1]}-{args.task}"
if args.extra_model_name == ""
else f"{args.model_name.split('/')[1]}-{args.task}-{args.extra_model_name}"
)
trainer.push_to_hub(
repo_name=repo_name,
use_auth_token=args.use_auth_token,
**kwargs,
)
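
        # With the defaults above and model_name="google/vit-base-patch16-224-in21k" (illustrative),
        # the pushed repository would be named "vit-base-patch16-224-in21k-image-classification-sagemaker".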