def evaluate_using_acc()

in torchmoji/finetuning.py [0:0]


def evaluate_using_acc(model, test_gen):
    """ Evaluation function using accuracy.

    Puts the model in evaluation mode, runs it on every batch produced by
    `test_gen` and returns the fraction of samples predicted correctly.
    Correct/total counts are accumulated over all batches, so batches of
    different sizes (e.g. a smaller final batch) are weighted by the number
    of samples they contain.

    # Arguments:
        model: Model to be evaluated. It must be callable on a batch,
            provide `eval()` and expose an `nb_classes` attribute.
        test_gen: Testing data iterator (DataLoader) yielding `(x, y)` pairs.

    # Returns:
        Accuracy of the given model as a NumPy float, or NaN when
        `test_gen` yields no samples.
    """

    # Validate on test_data
    model.eval()
    correct = 0
    total = 0
    # Evaluation only: no gradients are needed for the forward passes.
    with torch.no_grad():
        for x, y in test_gen:
            outs = model(x)
            if model.nb_classes > 2:
                # Predicted class = index of the largest output along dim 1.
                pred = torch.max(outs, 1)[1]
            else:
                # Threshold the raw output at 0 to obtain a 0/1 label.
                pred = (outs >= 0).long()
            # Flatten both sides: a (B, 1) output compared with (B,) labels
            # would otherwise broadcast to (B, B), and a batch of size one
            # must stay one-dimensional.
            pred = pred.reshape(-1)
            target = y.reshape(-1).to(pred.device)
            correct += (pred == target).sum().item()
            total += pred.numel()
    if total == 0:
        # Nothing was evaluated; keep the NaN result the mean of an empty
        # list used to give instead of dividing by zero.
        return np.float64('nan')
    return np.float64(correct) / total