in modeling.py [0:0]
def test_model(model, dataset, batch_size=128, augmentation=False):
    """
    Evaluates `model` on samples from the specified `dataset` using the specified
    `batch_size`. Returns predictions for all samples in the dataset. Optionally,
    test-time data `augmentation` can be enabled as well.

    Args:
        model: the network to evaluate. It is switched to eval mode via
            `model.eval()` and is left in eval mode when this function returns.
        dataset: dataset handed to `dataloading.load_datasampler`.
        batch_size: number of samples per batch requested from the sampler.
        augmentation: if True, the sampler is built with the transform returned
            by `dataloading.data_augmentation(train=False)`; otherwise no
            transform is applied.

    Returns:
        The model outputs for every batch, concatenated along dim 0 in the
        order the sampler yields them (the sampler is built with
        `shuffle=False`).

    Raises:
        RuntimeError: if the sampler yields no batches, so there is nothing
            to concatenate.
    """
    model.eval()

    # Inputs must be on the same device as the model's parameters. A model
    # without parameters has no preferred device, so fall back to the CPU
    # instead of letting a bare StopIteration escape.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # create data sampler:
    transform = dataloading.data_augmentation(train=False) if augmentation else None
    datasampler = dataloading.load_datasampler(
        dataset, batch_size=batch_size, transform=transform, shuffle=False
    )

    # perform test pass; no autograd graph is needed for inference:
    predictions = []
    with torch.no_grad():
        for sample in datasampler():
            # Only the features are fed to the model, so only they need to be
            # moved; `.to()` returns the tensor unchanged when it is already
            # on `device`.
            features = sample["features"].to(device=device)
            predictions.append(model(features))

    # torch.cat on an empty list fails with an unhelpful message:
    if not predictions:
        raise RuntimeError(
            "test_model: the data sampler produced no batches; "
            "cannot concatenate predictions"
        )
    # return all predictions:
    return torch.cat(predictions, dim=0)