in scripts/imagenet/utils.py [0:0]
def create_test_transforms(config, crop, scale, ten_crops):
    """Build the list of evaluation-time image transforms.

    Args:
        config: Mapping read for its ``"mean"`` and ``"std"`` entries, which
            are handed to ``transforms.Normalize``.
        crop: Crop size passed to ``transforms.TenCrop`` (ten-crop mode) or
            ``transforms.CenterCrop`` (single-crop mode).
        scale: Size passed to ``transforms.Resize``. The sentinel ``-1``
            skips the resize step entirely.
        ten_crops: When truthy, use ten-crop evaluation; otherwise use a
            single center crop.

    Returns:
        A plain list of transform objects, in application order. The list is
        not wrapped in a ``Compose`` here -- presumably the caller does that
        (NOTE(review): confirm against the call site).
    """
    normalize = transforms.Normalize(mean=config["mean"], std=config["std"])

    # Optional leading resize; -1 means "leave the input size alone".
    pipeline = [] if scale == -1 else [transforms.Resize(scale)]

    if not ten_crops:
        # Single-view path: crop, convert to tensor, normalize.
        pipeline.extend(
            (
                transforms.CenterCrop(crop),
                transforms.ToTensor(),
                normalize,
            )
        )
        return pipeline

    # Ten-crop path: TenCrop yields a collection of crops, so tensor
    # conversion and normalization are applied per crop before the results
    # are stacked into a single tensor.
    pipeline.extend(
        (
            transforms.TenCrop(crop),
            transforms.Lambda(
                lambda views: [transforms.ToTensor()(view) for view in views]
            ),
            transforms.Lambda(lambda views: [normalize(view) for view in views]),
            transforms.Lambda(lambda views: torch.stack(views)),
        )
    )
    return pipeline