in src/util.py [0:0]
def accuracy_f1_at_best_thresh(gold_labels, score_list):
    """Return the best accuracy and best F1 achievable by thresholding scores.

    An example is predicted positive when its score is at or above a
    threshold. Every realizable threshold is tried: one above all scores
    (everything predicted negative) and one at each distinct score value.
    Accuracy and F1 are maximized *independently*, so the two returned
    numbers may come from different thresholds.

    Args:
        gold_labels: Gold labels; an entry equal to 1 is positive, anything
            else is negative.
        score_list: One score per label; higher means "more positive".
            Any iterable is accepted.

    Returns:
        ``(best_accuracy, best_f1)`` as floats. ``best_f1`` is 0.0 when no
        threshold yields a true positive (including when there are no
        positive labels).

    Raises:
        ValueError: if the inputs are empty or have different lengths.
    """
    from itertools import groupby

    labels = list(gold_labels)
    scores = list(score_list)
    if len(labels) != len(scores):
        raise ValueError(
            "gold_labels and score_list differ in length: "
            f"{len(labels)} != {len(scores)}"
        )
    if not labels:
        raise ValueError("gold_labels and score_list must be non-empty")

    total_pos = sum(1 for y in labels if y == 1)

    # Start with a threshold above every score: everything is labeled
    # negative, so every positive is a false negative.
    tp = fp = 0
    fn = total_pos
    best_err = fn
    best_f1 = 0.0

    # Walk scores from highest to lowest, lowering the threshold one distinct
    # score at a time. Examples with equal scores must flip to "positive"
    # together (no threshold can separate them), so metrics are evaluated
    # only after an entire tie group has been consumed.
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    for _, tied in groupby(ranked, key=lambda pair: pair[0]):
        for _, y in tied:
            if y == 1:
                tp += 1
                fn -= 1
            else:
                fp += 1
        best_err = min(best_err, fp + fn)
        # F1 = 2PR / (P + R) = 2TP / (2TP + FP + FN). The denominator is > 0
        # because at least one example is now predicted positive, and the
        # expression is 0 when tp == 0 -- so no division-by-zero even when
        # there are no positive labels at all.
        best_f1 = max(best_f1, 2 * tp / (2 * tp + fp + fn))

    return (1 - best_err / len(labels), best_f1)