def test_model()

in modeling.py [0:0]


def test_model(model, dataset, batch_size=128, augmentation=False):
    """
    Evaluates `model` on samples from the specified `dataset` using the specified
    `batch_size`. Returns predictions for all samples in the dataset. Optionally,
    test-time data `augmentation` can be enabled as well.

    The model (and every submodule) is switched to evaluation mode for the
    duration of the pass. Afterwards each submodule's train/eval flag is
    restored to what it was before the call, also when evaluation raises. This
    lets the function be called from inside a training loop without leaving
    dropout / batch-norm disabled for subsequent training steps.

    Only `sample["features"]` is fed to the model, so only that entry is moved
    to the device of the model's parameters. The per-batch outputs are
    concatenated along dimension 0 and returned.

    Notes:
        - `model` must have at least one parameter; its device is taken from
          `next(model.parameters())`.
        - If the data sampler yields no batches, `torch.cat` raises a
          `RuntimeError` because there is nothing to concatenate.
    """

    # remember every submodule's train/eval flag so the caller's exact state
    # (including mixed modes, e.g. frozen batch-norm layers) can be restored:
    training_modes = [(module, module.training) for module in model.modules()]
    model.eval()

    try:
        # create data sampler:
        device = next(model.parameters()).device
        transform = (
            dataloading.data_augmentation(train=False) if augmentation else None
        )
        datasampler = dataloading.load_datasampler(
            dataset, batch_size=batch_size, transform=transform, shuffle=False
        )

        # perform test pass without tracking gradients:
        predictions = []
        with torch.no_grad():
            for sample in datasampler():
                # only the features are consumed by the model; `.to()` is a
                # no-op when the tensor already lives on `device`:
                features = sample["features"].to(device=device)
                predictions.append(model(features))
    finally:
        # restore the caller's train/eval flags, even if evaluation failed:
        for module, was_training in training_modes:
            module.training = was_training

    # return all predictions:
    return torch.cat(predictions, dim=0)