in Dassl.pytorch/dassl/data/transforms/transforms.py [0:0]
def _build_transform_train(cfg, choices, target_size, normalize):
    """Assemble the training transform pipeline selected by ``choices``.

    Each supported transform name found in ``choices`` contributes one
    step, in a fixed order that does not depend on the order of
    ``choices``. A line describing every added step is printed as the
    pipeline is built. ``ToTensor`` is always added; steps listed before
    it run on the un-converted image and steps after it run on the tensor.

    Args:
        cfg: Config object; options are read from ``cfg.INPUT``.
        choices: Container of transform names, queried with ``in``.
        target_size: Only used in the log message of the resize step.
        normalize: Transform appended when ``"normalize"`` is selected.

    Returns:
        A ``Compose`` wrapping the selected steps.
    """
    print("Building transform_train")
    steps = []

    interpolation = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
    size = cfg.INPUT.SIZE

    # Without any cropping transform nothing else fixes the image size,
    # so resize explicitly to make it match the target size.
    crop_names = ("random_crop", "random_resized_crop")
    if not any(name in choices for name in crop_names):
        print(f"+ resize to {target_size}")
        steps.append(Resize(size, interpolation=interpolation))

    # --- geometric transforms ---
    if "random_translation" in choices:
        print("+ random translation")
        steps.append(Random2DTranslation(size[0], size[1]))

    if "random_crop" in choices:
        padding = cfg.INPUT.CROP_PADDING
        print(f"+ random crop (padding = {padding})")
        steps.append(RandomCrop(size, padding=padding))

    if "random_resized_crop" in choices:
        crop_scale = cfg.INPUT.RRCROP_SCALE
        print(f"+ random resized crop (size={size}, scale={crop_scale})")
        steps.append(
            RandomResizedCrop(
                size, scale=crop_scale, interpolation=interpolation
            )
        )

    if "random_flip" in choices:
        print("+ random flip")
        steps.append(RandomHorizontalFlip())

    # --- augmentation policies ---
    if "imagenet_policy" in choices:
        print("+ imagenet policy")
        steps.append(ImageNetPolicy())

    if "cifar10_policy" in choices:
        print("+ cifar10 policy")
        steps.append(CIFAR10Policy())

    if "svhn_policy" in choices:
        print("+ svhn policy")
        steps.append(SVHNPolicy())

    if "randaugment" in choices:
        num_ops = cfg.INPUT.RANDAUGMENT_N
        magnitude = cfg.INPUT.RANDAUGMENT_M
        print(f"+ randaugment (n={num_ops}, m={magnitude})")
        steps.append(RandAugment(num_ops, magnitude))

    if "randaugment_fixmatch" in choices:
        num_ops = cfg.INPUT.RANDAUGMENT_N
        print(f"+ randaugment_fixmatch (n={num_ops})")
        steps.append(RandAugmentFixMatch(num_ops))

    if "randaugment2" in choices:
        num_ops = cfg.INPUT.RANDAUGMENT_N
        print(f"+ randaugment2 (n={num_ops})")
        steps.append(RandAugment2(num_ops))

    # --- photometric transforms ---
    if "colorjitter" in choices:
        brightness = cfg.INPUT.COLORJITTER_B
        contrast = cfg.INPUT.COLORJITTER_C
        saturation = cfg.INPUT.COLORJITTER_S
        hue = cfg.INPUT.COLORJITTER_H
        print(
            f"+ color jitter (brightness={brightness}, "
            f"contrast={contrast}, saturation={saturation}, hue={hue})"
        )
        jitter = ColorJitter(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            hue=hue,
        )
        steps.append(jitter)

    if "randomgrayscale" in choices:
        print("+ random gray scale")
        steps.append(RandomGrayscale(p=cfg.INPUT.RGS_P))

    if "gaussian_blur" in choices:
        kernel = cfg.INPUT.GB_K
        print(f"+ gaussian blur (kernel={kernel})")
        blur_prob = cfg.INPUT.GB_P
        steps.append(RandomApply([GaussianBlur(kernel)], p=blur_prob))

    # The tensor conversion is unconditional; everything below it is
    # placed after the conversion in the pipeline.
    print("+ to torch tensor of range [0, 1]")
    steps.append(ToTensor())

    if "cutout" in choices:
        num_holes = cfg.INPUT.CUTOUT_N
        hole_length = cfg.INPUT.CUTOUT_LEN
        print(f"+ cutout (n_holes={num_holes}, length={hole_length})")
        steps.append(Cutout(num_holes, hole_length))

    if "normalize" in choices:
        pixel_mean = cfg.INPUT.PIXEL_MEAN
        pixel_std = cfg.INPUT.PIXEL_STD
        print(f"+ normalization (mean={pixel_mean}, std={pixel_std})")
        steps.append(normalize)

    if "gaussian_noise" in choices:
        noise_mean = cfg.INPUT.GN_MEAN
        noise_std = cfg.INPUT.GN_STD
        print(f"+ gaussian noise (mean={noise_mean}, std={noise_std})")
        steps.append(GaussianNoise(noise_mean, noise_std))

    if "instance_norm" in choices:
        print("+ instance normalization")
        steps.append(InstanceNormalization())

    return Compose(steps)