in src/utils.py [0:0]
# Loaded `evaluate` metric modules, keyed by metric name. Filled lazily by
# `_load_metric` so each metric is loaded once per process rather than on
# every evaluation pass.
_METRIC_CACHE = {}


def _load_metric(name):
    """Return the `evaluate` metric called *name*, loading it on first use."""
    metric = _METRIC_CACHE.get(name)
    if metric is None:
        metric = evaluate.load(name)
        _METRIC_CACHE[name] = metric
    return metric


def compute_metrics(p):
    """Compute accuracy plus weighted precision, recall and F1.

    Args:
        p: A ``(logits, labels)`` pair. Presumably a Hugging Face
            ``EvalPrediction`` or an equivalent 2-tuple (TODO confirm caller).
            ``logits`` is argmaxed along axis 1, so it is expected to be
            ``(num_examples, num_classes)``. If it arrives as a tuple, only
            its first element is used. ``labels`` are the reference class ids.

    Returns:
        dict with keys ``"accuracy"``, ``"precision"``, ``"recall"`` and
        ``"f1"`` (in that order), each holding the score reported by the
        corresponding `evaluate` metric. Precision, recall and F1 use
        ``average="weighted"``.
    """
    logits, labels = p
    # Models that return extra outputs hand back a tuple whose first element is
    # the logits. Calling argmax on the whole tuple would either fail or stack
    # unrelated arrays and reduce over the wrong axis.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=1)

    results = {
        "accuracy": _load_metric("accuracy").compute(
            predictions=predictions, references=labels
        )["accuracy"],
    }
    # Each of these metrics reports its score under a key equal to its name.
    for name in ("precision", "recall", "f1"):
        results[name] = _load_metric(name).compute(
            predictions=predictions, references=labels, average="weighted"
        )[name]
    return results