in torchmoji/finetuning.py [0:0]
def evaluate_using_acc(model, test_gen):
    """ Evaluation function using accuracy.

    Puts `model` into eval mode and runs it over every batch yielded by
    `test_gen`. Returns the fraction of correctly classified samples over
    the whole set. Every sample counts once, so a short final batch is not
    over-weighted as it would be when averaging per-batch accuracies.

    # Arguments:
        model: Model to be evaluated. Must be callable on a batch of inputs
            and expose an integer `nb_classes` attribute.
        test_gen: Testing data iterator (DataLoader) yielding `(x, y)`
            batches, where `y` holds the class index of each sample
            (any shape that flattens to one label per sample).

    # Returns:
        Accuracy of the given model as a float in [0, 1], or NaN if
        `test_gen` yields no samples.
    """
    # Validate on test_data
    model.eval()
    nb_correct = 0
    nb_total = 0
    for x, y in test_gen:
        outs = model(x)
        if model.nb_classes > 2:
            # Multi-class: the predicted class is the arg-max over dim 1.
            pred = torch.max(outs, 1)[1]
        else:
            # Binary: one output per sample. Thresholding at 0 assumes
            # `outs` are logits (sigmoid(0) == 0.5).
            # NOTE(review): confirm the model is built with logit outputs.
            pred = (outs >= 0).long()
        # Flatten both sides to 1-D. Without this, (batch, 1) predictions
        # would broadcast against (batch,) labels into a (batch, batch)
        # matrix, and squeezing a batch of size 1 would yield a 0-d value.
        pred = pred.cpu().contiguous().view(-1)
        labels = y.cpu().contiguous().view(-1).long()
        nb_correct += int((pred == labels).sum())
        nb_total += labels.numel()
    if nb_total == 0:
        # Same result as np.mean([]) on an empty iterator, without the warning.
        return float('nan')
    return float(nb_correct) / nb_total