# def create_transforms()
#
# in scripts/imagenet/utils.py [0:0]


def create_transforms(input_config):
    """Create training and validation transform lists from configuration.

    Parameters
    ----------
    input_config : dict
        Dictionary containing the configuration options for input
        pre-processing. The following keys are read:

        - ``"mean"``, ``"std"``: passed to ``transforms.Normalize``.
        - ``"scale_train"``, ``"scale_val"``: size passed to
          ``transforms.Resize``; the value ``-1`` disables the resize step.
        - ``"crop_train"``: size passed to ``transforms.RandomResizedCrop``.
        - ``"crop_val"``: size passed to ``transforms.CenterCrop``.
        - ``"color_jitter_train"``: if truthy, append ``ColorJitter()``.
        - ``"lighting_train"``: if truthy, append ``Lighting()``.

    Returns
    -------
    train_transforms : list
        List of transforms to be applied to the input during training.
    val_transforms : list
        List of transforms to be applied to the input during validation.

    Raises
    ------
    KeyError
        If any of the keys listed above is missing from ``input_config``.
    """
    # The same Normalize instance is shared by both pipelines.
    normalize = transforms.Normalize(mean=input_config["mean"], std=input_config["std"])

    train_transforms = []
    if input_config["scale_train"] != -1:
        # `transforms.Scale` is the deprecated former name of `Resize` and is
        # absent from newer torchvision releases; use `Resize`, as the
        # validation pipeline below already does.
        train_transforms.append(transforms.Resize(input_config["scale_train"]))
    train_transforms += [
        transforms.RandomResizedCrop(input_config["crop_train"]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    # The optional augmentations are inserted after ToTensor and before
    # normalization. NOTE(review): ColorJitter/Lighting are defined outside
    # this block; presumably they operate on tensors -- confirm.
    if input_config["color_jitter_train"]:
        train_transforms.append(ColorJitter())
    if input_config["lighting_train"]:
        train_transforms.append(Lighting())
    train_transforms.append(normalize)

    val_transforms = []
    if input_config["scale_val"] != -1:
        val_transforms.append(transforms.Resize(input_config["scale_val"]))
    val_transforms += [
        transforms.CenterCrop(input_config["crop_val"]),
        transforms.ToTensor(),
        normalize,
    ]

    return train_transforms, val_transforms