in src/common/multi_label_metrics.py [0:0]
def metrics_multi_label(targets, probs, threshold=0.5):
    """Compute row-averaged metrics for a multi-label classification result.

    Every row (sample) of ``targets``/``probs`` is scored on its own, and each
    metric is the mean of its per-row scores rounded to 4 decimals (0 when no
    row contributed a score). A label counts as predicted where
    ``probs >= threshold``.

    :param targets: true 0/1 indicator matrix, indexed ``targets[row, label]``.
    :param probs: predicted probability matrix with the same indexing.
    :param threshold: probabilities ``>= threshold`` are treated as positive.
    :return: dict with the averaged "acc", "jaccard", "prec", "recall", "f1",
        "pr_auc" and "roc_auc" scores, followed by "fmax", "pmax", "rmax" and
        "tmax" as returned by ``f_max(targets, probs)``.
    """
    true_index_rows = relevant_indexes(targets)
    pred_index_rows = relevant_indexes((probs >= threshold).astype(int))

    # One list of per-row scores per metric; filled in row order.
    scores = {
        name: []
        for name in ("acc", "prec", "recall", "jaccard", "f1", "roc_auc", "pr_auc")
    }

    for row in range(targets.shape[0]):
        true_idx = true_index_rows[row]
        pred_idx = pred_index_rows[row]
        true_set = set(true_idx)
        pred_set = set(pred_idx)
        n_union = len(true_set | pred_set)
        n_common = len(true_set & pred_set)

        if not n_union:
            # Nothing relevant and nothing predicted: every metric, including
            # roc_auc and pr_auc, is recorded as a perfect 1.0 for this row.
            for per_row in scores.values():
                per_row.append(1.0)
            continue

        # Hamming-style accuracy: 1 - |true symmetric-difference pred| / n_labels.
        scores["acc"].append(1.0 - (n_union - n_common) / targets.shape[1])

        n_pred = len(pred_idx)
        n_true = len(true_idx)
        # No predictions -> precision 0; no true labels -> recall 1.
        precision = n_common / n_pred if n_pred > 0 else 0.0
        recall = n_common / n_true if n_true > 0 else 1.0
        scores["prec"].append(precision)
        scores["recall"].append(recall)

        scores["jaccard"].append(n_common / n_union)

        if precision + recall != 0:
            f1 = 2.0 * precision * recall / (precision + recall)
        else:
            f1 = 0.0
        scores["f1"].append(f1)

        true_row = targets[row, :]
        prob_row = probs[row, :]
        # ROC AUC is only defined when the row contains more than one class.
        if len(np.unique(true_row)) > 1:
            scores["roc_auc"].append(roc_auc_macro(true_row, prob_row))
        # NOTE(review): unlike roc_auc, pr_auc is taken even for single-class
        # rows; whether that is well-defined depends on pr_auc_macro — confirm.
        scores["pr_auc"].append(pr_auc_macro(true_row, prob_row))

    fmax, pmax, rmax, tmax, _ = f_max(targets, probs)

    def _avg(values):
        # Mean rounded to 4 decimals; integer 0 when the list is empty.
        return round(sum(values) / len(values), 4) if values else 0

    summary = {
        name: _avg(scores[name])
        for name in ("acc", "jaccard", "prec", "recall", "f1", "pr_auc", "roc_auc")
    }
    summary.update(fmax=fmax, pmax=pmax, rmax=rmax, tmax=tmax)
    return summary