in src/dataset.py [0:0]
def _tencrop_to_tensor(crops):
    """Convert every crop produced by ``transforms.TenCrop`` to a tensor.

    Module-level function rather than a lambda so the composed transform can
    be pickled (e.g. by DataLoader workers started with the ``spawn`` method).
    """
    to_tensor = transforms.ToTensor()
    return [to_tensor(crop) for crop in crops]


def _tencrop_normalize(crops):
    """Normalize each crop with ``NORMALIZE_IMAGENET`` and stack them into one tensor.

    Module-level function rather than a lambda so it stays picklable.
    """
    return torch.stack([NORMALIZE_IMAGENET(crop) for crop in crops])


def getImagenetTransform(name, img_size=256, crop_size=224, normalization=True, as_list=False, differentiable=False):
    """Build the ImageNet preprocessing pipeline selected by ``name``.

    Args:
        name: Augmentation preset. ``"random"``, ``"center"`` and ``"none"``
            are accepted in both modes; ``"tencrop"`` is accepted only when
            ``differentiable`` is False.
        img_size: Resize size used by the ``"center"`` and ``"tencrop"``
            presets (also passed to the differentiable ``CenterCrop``).
        crop_size: Size of the output crop(s).
        normalization: If True, apply ``NORMALIZE_IMAGENET`` after the
            tensor conversion.
        as_list: If True, return the individual steps as a list instead of
            composing them.
        differentiable: If True, the augmentation step is a single module
            (``RandomResizedCropFlip``, ``CenterCrop`` or
            ``DifferentiableDataAugmentation``) kept separate from the
            tensor-conversion / normalization steps.

    Returns:
        * ``as_list=True``: a list of steps — the augmentation step(s)
          first, then tensor conversion and optional normalization.
        * ``differentiable=True``: a tuple ``(augmentation, postprocess)``
          where ``postprocess`` is a ``transforms.Compose``.
        * otherwise: a single ``transforms.Compose``.

        For ``"tencrop"`` the pipeline yields a list of per-crop tensors, or
        one stacked tensor when ``normalization`` is True.

    Raises:
        ValueError: If ``name`` is not a supported preset for the chosen mode.
    """
    if differentiable:
        if name == "random":
            transform = RandomResizedCropFlip(crop_size)
        elif name == "center":
            transform = CenterCrop(img_size, crop_size)
        elif name == "none":
            transform = DifferentiableDataAugmentation()
        else:
            # Explicit error instead of `assert`, which is stripped under -O.
            raise ValueError(
                f"Unknown differentiable transform {name!r}; "
                "expected 'random', 'center' or 'none'"
            )
    else:
        if name == "random":
            transform = [
                transforms.RandomResizedCrop(crop_size),
                transforms.RandomHorizontalFlip(),
            ]
        elif name == "tencrop":
            transform = [
                transforms.Resize(img_size),
                transforms.TenCrop(crop_size),
            ]
        elif name == "center":
            transform = [
                transforms.Resize(img_size),
                transforms.CenterCrop(crop_size),
            ]
        elif name == "none":
            transform = []
        else:
            raise ValueError(
                f"Unknown transform {name!r}; "
                "expected 'random', 'tencrop', 'center' or 'none'"
            )

    # TenCrop emits a sequence of crops, so each step has to be mapped over it.
    if name == "tencrop":
        postprocess = [transforms.Lambda(_tencrop_to_tensor)]
    else:
        postprocess = [transforms.ToTensor()]

    if normalization:
        if name == "tencrop":
            postprocess.append(transforms.Lambda(_tencrop_normalize))
        else:
            postprocess.append(NORMALIZE_IMAGENET)

    if as_list:
        # In differentiable mode `transform` is a single module, not a list;
        # wrap it so the concatenation does not raise TypeError.
        steps = [transform] if differentiable else transform
        return steps + postprocess
    if differentiable:
        return transform, transforms.Compose(postprocess)
    return transforms.Compose(transform + postprocess)