def test()

in pytorch_alternatives/custom_pytorch_nlp/src/main.py [0:0]


def test(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and print validation metrics.

    The model is put in eval mode and run under ``torch.no_grad()``. For each
    batch the summed binary cross-entropy between the model output and
    ``target`` is accumulated. A sample counts as correct when the argmax over
    dim 1 of the output equals the argmax over dim 1 of the target.

    Assumes ``output`` holds probabilities in [0, 1] (as
    ``F.binary_cross_entropy`` requires) and ``target`` is a one-hot float
    tensor of the same shape. TODO confirm against the model/dataset.

    Metrics are averaged over the number of samples the loader actually
    yielded, not over ``len(test_loader.dataset)``. This keeps them correct
    when the loader uses a subset sampler or ``drop_last=True``.

    Args:
        model: ``torch.nn.Module`` to evaluate. It is left in eval mode.
        test_loader: iterable of ``(data, target)`` tensor batches
            (presumably a ``torch.utils.data.DataLoader``).
        device: device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(val_loss, val_acc)``: mean per-sample loss and accuracy.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    n_samples = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum (not mean) per batch so the final division gives a true
            # per-sample average even when the last batch is smaller.
            total_loss += F.binary_cross_entropy(output, target, reduction="sum").item()
            # Predicted class = index of the highest score; true class = hot index.
            pred = output.argmax(dim=1)
            target_index = target.argmax(dim=1)
            correct += pred.eq(target_index).sum().item()
            n_samples += target.size(0)

    if n_samples == 0:
        raise ValueError("test_loader yielded no samples; cannot compute metrics")

    test_loss = total_loss / n_samples  # Average loss over evaluated samples
    accuracy = correct / n_samples
    print(f"val_loss: {test_loss:.4f}, val_acc: {accuracy:.4f}")
    return test_loss, accuracy


# The name starts with "test"; opt out of pytest collection so importing this
# module into a test file does not turn the helper into a (failing) test case.
test.__test__ = False