in classification/train_edu_bert.py
import evaluate
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix


def compute_metrics(eval_pred):
    # Load metrics from the Hugging Face `evaluate` hub.
    precision_metric = evaluate.load("precision")
    recall_metric = evaluate.load("recall")
    f1_metric = evaluate.load("f1")
    accuracy_metric = evaluate.load("accuracy")

    # The model emits continuous scores, so round them to the nearest
    # integer and clamp to the 0-5 label range before scoring.
    logits, labels = eval_pred
    preds = np.round(logits.squeeze()).clip(0, 5).astype(int)
    labels = np.round(labels.squeeze()).astype(int)

    # Macro averaging weights every score class equally, regardless of
    # class imbalance in the validation set.
    precision = precision_metric.compute(
        predictions=preds, references=labels, average="macro"
    )["precision"]
    recall = recall_metric.compute(
        predictions=preds, references=labels, average="macro"
    )["recall"]
    f1 = f1_metric.compute(predictions=preds, references=labels, average="macro")["f1"]
    accuracy = accuracy_metric.compute(predictions=preds, references=labels)["accuracy"]

    # Print a full per-class breakdown for manual inspection at eval time.
    report = classification_report(labels, preds)
    cm = confusion_matrix(labels, preds)
    print("Validation Report:\n" + report)
    print("Confusion Matrix:\n" + str(cm))

    return {
        "precision": precision,
        "recall": recall,
        "f1_macro": f1,
        "accuracy": accuracy,
    }
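
For context, a function like this is normally passed to a Hugging Face Trainer through its compute_metrics hook, which calls it after each evaluation pass with the model's predictions and labels. The sketch below shows one minimal way to wire that up; the checkpoint name, output directory, hyperparameters, and the train_dataset / eval_dataset variables are illustrative assumptions, not values taken from train_edu_bert.py.

from transformers import (
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)

# Placeholder checkpoint; num_labels=1 gives a single regression-style output,
# which compute_metrics above rounds and clips into discrete 0-5 scores.
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=1,
)

training_args = TrainingArguments(
    output_dir="edu-bert-checkpoints",  # placeholder path
    evaluation_strategy="epoch",
    per_device_train_batch_size=16,
    num_train_epochs=3,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # assumed: tokenized dataset with a float "labels" column
    eval_dataset=eval_dataset,    # assumed: tokenized validation split
    compute_metrics=compute_metrics,  # the function defined above
)
trainer.train()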