in source/sagemaker/sagemaker_graph_fraud_detection/dgl_fraud_detection/utils.py [0:0]
def get_metrics(pred, pred_proba, labels, mask, out_dir):
    """Compute binary-classification metrics over the masked samples.

    Only positions where ``mask`` is non-zero (after casting to ``int``) are
    evaluated. Class ``1`` is treated as the positive class and class ``0``
    as the negative class.

    Args:
        pred: Hard class predictions, indexable with a NumPy index tuple
            (presumably a 1-D NumPy array of 0/1 values - confirm at caller).
        pred_proba: Scores for the positive class, aligned with ``pred``.
        labels: Ground-truth labels; must expose ``asnumpy()`` (e.g. an MXNet
            NDArray - assumption, verify against caller).
        mask: Selection mask aligned with ``labels``; must expose
            ``asnumpy()``.
        out_dir: Directory under which ``roc_curve.png`` and ``pr_curve.png``
            paths are built and handed to ``save_roc_curve`` /
            ``save_pr_curve``.

    Returns:
        Tuple ``(acc, f1, precision, recall, roc_auc, pr_auc, ap,
        confusion_matrix)`` where ``confusion_matrix`` is a 2x2
        ``pandas.DataFrame`` with predicted classes as rows and label classes
        as columns.
    """
    labels = labels.asnumpy().flatten().astype(int)
    mask = mask.asnumpy().flatten().astype(int)

    # Resolve the selected positions once and reuse them for every array.
    selected = np.where(mask)
    labels, pred, pred_proba = labels[selected], pred[selected], pred_proba[selected]

    # Divide by the number of evaluated samples rather than mask.sum(): both
    # agree for a 0/1 mask, but only the sample count matches the selection
    # above when the mask carries other non-zero values.
    acc = (pred == labels).sum() / labels.size

    pred_pos, pred_neg = pred == 1, pred == 0
    label_pos, label_neg = labels == 1, labels == 0
    true_pos = (pred_pos & label_pos).sum()
    false_pos = (pred_pos & label_neg).sum()
    false_neg = (pred_neg & label_pos).sum()
    true_neg = (pred_neg & label_neg).sum()

    # Fall back to 0 when a denominator is empty instead of dividing by zero.
    predicted_positives = true_pos + false_pos
    actual_positives = true_pos + false_neg
    precision = true_pos / predicted_positives if predicted_positives > 0 else 0
    recall = true_pos / actual_positives if actual_positives > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    confusion_matrix = pd.DataFrame(
        np.array([[true_pos, false_pos], [false_neg, true_neg]]),
        columns=["labels positive", "labels negative"],
        index=["predicted positive", "predicted negative"],
    )

    ap = average_precision_score(labels, pred_proba)
    fpr, tpr, _ = roc_curve(labels, pred_proba)
    prc, rec, _ = precision_recall_curve(labels, pred_proba)
    roc_auc = auc(fpr, tpr)
    pr_auc = auc(rec, prc)

    save_roc_curve(fpr, tpr, roc_auc, os.path.join(out_dir, "roc_curve.png"))
    save_pr_curve(prc, rec, pr_auc, ap, os.path.join(out_dir, "pr_curve.png"))

    return acc, f1, precision, recall, roc_auc, pr_auc, ap, confusion_matrix