in sagemaker/src/hf_train_deploy.py [0:0]
def train(args):
    """Fine-tune, evaluate, and save a sequence classification model."""
    set_seed(args.seed)
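
    # _get_dataset is a helper defined elsewhere in this file; it is assumed to read
    # the named CSV from the training channel and return a tokenized dataset.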
    train_dataset = _get_dataset(args.training_dir, "train.csv", args.text_column, args.label_column)
    valid_dataset = _get_dataset(args.training_dir, "valid.csv", args.text_column, args.label_column)
    # compute metrics function for binary classification
    def compute_metrics(pred):
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary")
        acc = accuracy_score(labels, preds)
        return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
    # download model from model hub
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name)

    # define training args
    training_args = TrainingArguments(
        output_dir=args.output_data_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        warmup_steps=args.warmup_steps,
        seed=args.seed,
        save_steps=500,
        save_total_limit=2,
        evaluation_strategy="steps",
        eval_steps=50,
        logging_steps=50,
        logging_dir=args.output_data_dir,
        learning_rate=float(args.learning_rate),
    )
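    # load_best_model_at_end is not enabled here, so trainer.save_model() below
    # writes the final-step weights rather than the best checkpoint by eval metric.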
    # create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
    )

    # train model
    trainer.train()

    # evaluate model
    eval_result = trainer.evaluate(eval_dataset=valid_dataset)
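    # evaluate() returns a metrics dict whose keys carry an "eval_" prefix
    # (eval_loss plus the eval_-prefixed values from compute_metrics).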
    # write eval results to a file that can be retrieved later from the S3 output location
    with open(os.path.join(args.output_data_dir, "eval_results.txt"), "w") as writer:
        print("***** Eval results *****")
        for key, value in sorted(eval_result.items()):
            writer.write(f"{key} = {value}\n")

    # save the model to model_dir, which SageMaker uploads to S3 as the model artifact
    trainer.save_model(args.model_dir)
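
# A minimal sketch (not from the original file) of how a train() entry point like this
# is typically driven in a SageMaker Hugging Face training script: argparse reads the
# estimator hyperparameters plus the SM_* directories injected by the training container,
# then calls train(args). Defaults and the "train" channel name are assumptions.
if __name__ == "__main__":
    import argparse
    import os

    parser = argparse.ArgumentParser()
    # hyperparameters passed via the HuggingFace estimator (assumed names/defaults)
    parser.add_argument("--model_name", type=str, default="distilbert-base-uncased")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--eval_batch_size", type=int, default=64)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--learning_rate", type=str, default="5e-5")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--text_column", type=str, default="text")
    parser.add_argument("--label_column", type=str, default="label")
    # directories provided by the SageMaker training container
    parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR"))
    parser.add_argument("--output_data_dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR"))
    parser.add_argument("--training_dir", type=str, default=os.environ.get("SM_CHANNEL_TRAIN"))

    train(parser.parse_args())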