in scripts/imagenet/utils.py [0:0]
def create_transforms(input_config):
    """Create transforms from configuration

    Parameters
    ----------
    input_config : dict
        Dictionary containing the configuration options for input pre-processing.
        The keys read here are ``mean``, ``std``, ``scale_train``, ``crop_train``,
        ``color_jitter_train``, ``lighting_train``, ``scale_val`` and ``crop_val``.
        A ``scale_train`` / ``scale_val`` value of ``-1`` skips the corresponding
        resize step.

    Returns
    -------
    train_transforms : list
        List of transforms to be applied to the input during training.
    val_transforms : list
        List of transforms to be applied to the input during validation.
    """
    # The same Normalize instance is used as the last step of both pipelines.
    normalize = transforms.Normalize(mean=input_config["mean"], std=input_config["std"])

    # --- Training pipeline ---
    train_transforms = []
    if input_config["scale_train"] != -1:
        # Use Resize rather than the deprecated Scale alias, so the training
        # pipeline agrees with the validation pipeline below.
        train_transforms.append(transforms.Resize(input_config["scale_train"]))
    train_transforms += [
        transforms.RandomResizedCrop(input_config["crop_train"]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    # The optional augmentations are placed after ToTensor, so they receive its
    # output. NOTE(review): ColorJitter and Lighting are resolved from this
    # module's namespace (not `transforms.`) -- presumably tensor-based helpers
    # defined or imported elsewhere in the file; confirm.
    if input_config["color_jitter_train"]:
        train_transforms.append(ColorJitter())
    if input_config["lighting_train"]:
        train_transforms.append(Lighting())
    train_transforms.append(normalize)

    # --- Validation pipeline ---
    val_transforms = []
    if input_config["scale_val"] != -1:
        val_transforms.append(transforms.Resize(input_config["scale_val"]))
    val_transforms += [
        transforms.CenterCrop(input_config["crop_val"]),
        transforms.ToTensor(),
        normalize,
    ]

    return train_transforms, val_transforms