def create_test_transforms()

in scripts/imagenet/utils.py [0:0]


class _ApplyToEach:
    """Callable that applies ``fn`` to every item of a sequence, returning a list.

    Used instead of ``lambda crops: [fn(c) for c in crops]`` because lambdas
    created inside a function cannot be pickled, which breaks DataLoader
    worker processes started with the "spawn" method. A module-level class
    instance is picklable as long as ``fn`` itself is.
    """

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, items):
        return [self.fn(item) for item in items]

    def __repr__(self):
        return f"{type(self).__name__}({self.fn!r})"


def _stack_crops(crops):
    """Stack a sequence of tensors along a new leading (crop) dimension."""
    return torch.stack(crops)


def create_test_transforms(config, crop, scale, ten_crops):
    """Build the list of evaluation-time image transforms.

    Args:
        config: Mapping providing ``"mean"`` and ``"std"`` entries, passed to
            ``transforms.Normalize``.
        crop: Crop size passed to ``transforms.TenCrop`` or
            ``transforms.CenterCrop``.
        scale: Size passed to ``transforms.Resize`` before cropping. ``-1``
            (or ``None``) skips the resize step.
        ten_crops: If true, take ten crops per image, convert and normalize
            each one, and stack them into a single tensor with a new leading
            crop dimension. Otherwise take a single center crop.

    Returns:
        A plain list of transforms (not composed); presumably the caller
        wraps it in ``transforms.Compose`` -- confirm at call sites.
    """
    normalize = transforms.Normalize(mean=config["mean"], std=config["std"])

    val_transforms = []
    # -1 is the existing "no resize" sentinel; None is accepted as an alias.
    if scale is not None and scale != -1:
        val_transforms.append(transforms.Resize(scale))
    if ten_crops:
        val_transforms += [
            transforms.TenCrop(crop),
            # Picklable helpers instead of lambdas so multi-worker loading
            # with the "spawn" start method keeps working.
            transforms.Lambda(_ApplyToEach(transforms.ToTensor())),
            transforms.Lambda(_ApplyToEach(normalize)),
            transforms.Lambda(_stack_crops),
        ]
    else:
        val_transforms += [
            transforms.CenterCrop(crop),
            transforms.ToTensor(),
            normalize,
        ]

    return val_transforms