def find_f1_threshold()

in torchmoji/finetuning.py [0:0]


def find_f1_threshold(model, val_gen, test_gen, average='binary'):
    """ Choose a threshold for F1 based on the validation dataset
        (see https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4442797/
        for details on why to find another threshold than simply 0.5)

    The threshold is selected on the validation data only; the test data
    is then scored once with that selected threshold.

    # Arguments:
        model: pyTorch model
        val_gen: Validation set dataloader. Iterated once; must yield
            (X, y) pairs.
        test_gen: Testing set dataloader. Iterated once; must yield
            (X, y) pairs.
        average: Averaging mode forwarded unchanged to `f1_score`.

    # Returns:
        F1 score for the given data and
        the corresponding F1 threshold

    # Raises:
        ValueError: if either dataloader yields no batches.
    """
    def _as_array(batch):
        # Tensor-like objects (anything exposing `detach`) are cut from the
        # autograd graph and moved to host memory before NumPy reads them;
        # everything else is handed to NumPy directly.
        if hasattr(batch, 'detach'):
            batch = batch.detach().cpu().numpy()
        return np.atleast_1d(np.asarray(batch))

    def _gather(gen, name):
        # Run the model over every batch and return (labels, predictions)
        # concatenated along the first (sample) axis.
        labels, preds = [], []
        for X, y in gen:
            labels.append(_as_array(y))
            preds.append(_as_array(model(X)))
        if not labels:
            raise ValueError('{} yielded no batches'.format(name))
        labels = np.concatenate(labels, axis=0)
        preds = np.concatenate(preds, axis=0)
        # NOTE(review): presumably a single-output model returns
        # (batch, 1) while labels are (batch,) - confirm against callers.
        # Dropping that trailing singleton axis keeps both arrays aligned.
        if preds.ndim == labels.ndim + 1 and preds.shape[-1] == 1:
            preds = preds[..., 0]
        return labels, preds

    thresholds = np.arange(0.01, 0.5, step=0.01)

    model.eval()
    y_val, y_pred_val = _gather(val_gen, 'val_gen')
    # The test outputs must come from test_gen; reusing the validation
    # outputs here would report a validation score as the test score.
    y_test, y_pred_test = _gather(test_gen, 'test_gen')

    f1_scores = [f1_score(y_val, y_pred_val > t, average=average)
                 for t in thresholds]

    # np.argmax returns the first maximum, so ties go to the lowest threshold.
    best_t = thresholds[np.argmax(f1_scores)]
    f1_test = f1_score(y_test, y_pred_test > best_t, average=average)
    return f1_test, best_t