def test()

in sagemaker-voice-classification/notebook/train.py [0:0]


def _prepare_batch(data, target, resampler=None):
    """Shape one loader batch as ``(batch, 1, features)``, optionally resampling it.

    When ``resampler`` is given, its ``fit_resample(X, y)`` is applied to a 2-D
    view of the batch first (this reproduces the old test-time oversampling).
    """
    if resampler is not None:
        # Resamplers such as imblearn's RandomOverSampler need a 2-D feature
        # matrix. Flatten everything after the batch axis; unlike np.squeeze,
        # this keeps the batch axis when a batch holds a single sample.
        features = data.reshape(data.shape[0], -1)
        data_resampled, target_resampled = resampler.fit_resample(features, target)
        data = torch.as_tensor(data_resampled)
        target = torch.as_tensor(target_resampled)
    # The model is fed (batch, 1, features), the layout the original
    # squeeze + unsqueeze_(-2) sequence produced.
    return data.reshape(data.shape[0], 1, -1), target


def test(model, test_loader, device, *, resampler=None):
    """Evaluate ``model`` on ``test_loader``, print metrics, return loss and accuracy.

    Args:
        model: torch module called as ``model(data)`` on batches shaped
            ``(batch, 1, features)``. Its output must be 3-D. The slice taken
            below is treated as per-class log-probabilities (it is fed to
            ``F.nll_loss``).
        test_loader: iterable yielding ``(data, target)`` tensor batches.
        device: device each batch is moved to before the forward pass.
        resampler: optional object with ``fit_resample(X, y)``, e.g. the
            training ``RandomOverSampler``. Defaults to ``None``, so the test
            set is evaluated exactly as loaded. Oversampling a test set
            distorts every reported metric. Pass the sampler explicitly only
            to reproduce the previous behaviour.

    Returns:
        ``(test_loss, accuracy)``, where ``test_loss`` is the mean per-sample
        NLL loss.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    actuals = []
    predictions = []
    class1_log_probs = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = _prepare_batch(data, target, resampler)
            data, target = data.to(device), target.to(device)

            # Take index 0 along dim 1 of the 3-D output. This is identical to
            # the original ``output.permute(1, 0, 2)[0]`` and leaves one row
            # of class scores per sample.
            output = model(data)[:, 0, :]
            test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # index of the max log-probability

            actuals.extend(target.cpu().numpy())
            predictions.extend(pred.cpu().numpy().ravel())
            # Column 1 holds the class-1 log-probability. It is used as the
            # ROC-AUC score below, which assumes a binary task.
            class1_log_probs.extend(output[:, 1].cpu().numpy())
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute test metrics")

    test_loss /= total
    accuracy = accuracy_score(actuals, predictions)
    if len(np.unique(actuals)) < 2:
        # roc_auc_score is undefined (and raises) when only one class is present.
        rocauc = float("nan")
    else:
        rocauc = roc_auc_score(actuals, np.exp(class1_log_probs))
    precision = precision_score(actuals, predictions, average='weighted')
    recall = recall_score(actuals, predictions, average='weighted')
    f1 = f1_score(actuals, predictions, average='weighted')
    # F2 (beta=2) weights recall above precision. The previous beta=0.5
    # computed F0.5 even though it was stored and logged as "F2".
    f2 = fbeta_score(actuals, predictions, average='weighted', beta=2)

    print(
        "Test set: Average loss: {:.4f}, F1: {:.4f}, F2: {:.4f}, Precision: {:.4f}, Recall: {:.4f}, ROCAUC: {:.4f}, Accuracy: {:.4f}, corrected prediction ratio: {}/{}".format(
            test_loss, f1, f2, precision, recall, rocauc, accuracy, correct, total
        )
    )
    print("\n")
    return test_loss, accuracy