in torchmoji/finetuning.py [0:0]
def find_f1_threshold(model, val_gen, test_gen, average='binary'):
    """ Choose a threshold for F1 based on the validation dataset
        (see https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4442797/
        for details on why to find another threshold than simply 0.5)

    Every candidate threshold in 0.01, 0.02, ..., 0.49 is applied to the
    model's validation-set outputs. The threshold giving the highest
    validation F1 is then used to binarize the test-set outputs.

    # Arguments:
        model: pyTorch model. It is switched to eval mode (and left there).
            Its outputs are presumably scores/probabilities comparable to
            thresholds in (0, 0.5) - TODO confirm against callers.
        val_gen: Validation set dataloader yielding (X, y) batches.
        test_gen: Testing set dataloader yielding (X, y) batches.
        average: Averaging mode forwarded to sklearn's `f1_score`.

    # Returns:
        Tuple (f1_test, best_t): the F1 score on the test set obtained with
        threshold `best_t`, the threshold that maximized validation F1.

    # Raises:
        ValueError: if either dataloader yields no batches.
    """
    import torch  # local import keeps this function self-contained

    def _collect(data_gen, split_name):
        # Run the model over every batch and stack labels and outputs into
        # single numpy arrays. Comparing against a scalar threshold and
        # calling f1_score both need arrays, not lists of per-batch tensors.
        # no_grad avoids keeping an autograd graph alive for every batch.
        labels, outputs = [], []
        with torch.no_grad():
            for X, y in data_gen:
                labels.append(torch.as_tensor(y).detach().cpu().numpy())
                outputs.append(
                    torch.as_tensor(model(X)).detach().cpu().numpy())
        if not outputs:
            raise ValueError(
                '{} dataloader yielded no batches'.format(split_name))
        return np.concatenate(labels), np.concatenate(outputs)

    model.eval()
    y_val, y_pred_val = _collect(val_gen, 'Validation')
    # Test labels/outputs must come from test_gen, never from the
    # validation pass.
    y_test, y_pred_test = _collect(test_gen, 'Test')

    thresholds = np.arange(0.01, 0.5, step=0.01)
    f1_scores = [f1_score(y_val, y_pred_val > t, average=average)
                 for t in thresholds]
    # np.argmax returns the first maximum, i.e. the smallest tied threshold.
    best_t = thresholds[np.argmax(f1_scores)]
    f1_test = f1_score(y_test, y_pred_test > best_t, average=average)
    return f1_test, best_t