def accuracy_f1_at_best_thresh()

in src/util.py [0:0]


def accuracy_f1_at_best_thresh(gold_labels, score_list):
    """Return the best accuracy and the best F1 over all score thresholds.

    Examples are ranked by descending score and every boundary between two
    distinct scores is tried as a threshold: everything ranked at or above
    the boundary is predicted positive, everything below it negative.  The
    "predict everything negative" threshold is the starting point.

    Args:
        gold_labels: Sized sequence of gold labels.  A label equal to ``1``
            is positive; anything else is treated as negative.
        score_list: Numeric scores aligned with ``gold_labels``; a higher
            score means "more likely positive".

    Returns:
        A ``(best_accuracy, best_f1)`` tuple.  The two values are maximised
        independently, so they may come from different thresholds.  Returns
        ``(0.0, 0.0)`` when there are no examples; F1 is ``0.0`` when no
        gold label is positive.
    """
    num_examples = len(gold_labels)
    if num_examples == 0:
        # Nothing to score; avoid dividing by zero in the accuracy below.
        return (0.0, 0.0)

    # Highest score first, so the threshold sweeps from "everything negative"
    # towards "everything positive".
    ranked = sorted(
        zip(score_list, gold_labels), key=lambda pair: pair[0], reverse=True
    )

    tp = 0
    fp = 0
    total_pos = sum(1 for y in gold_labels if y == 1)
    fn = total_pos
    best_err = fn  # errors made when every example is labelled negative
    best_f1 = 0.0
    last_index = len(ranked) - 1

    for i, (score, label) in enumerate(ranked):
        if label == 1:
            tp += 1
            fn -= 1
        else:
            fp += 1

        # A threshold cannot separate examples with identical scores, so only
        # evaluate once the whole group of tied scores has been consumed.
        if i < last_index and ranked[i + 1][0] == score:
            continue

        best_err = min(best_err, fp + fn)

        # With no true positives F1 is 0 and cannot beat the running best.
        # This guard also keeps total_pos == 0 from dividing by zero.
        if tp:
            precision = tp / (i + 1)  # i + 1 examples are predicted positive
            recall = tp / total_pos
            f1 = 2 * precision * recall / (precision + recall)
            best_f1 = max(best_f1, f1)

    return (1 - best_err / num_examples, best_f1)