scripts/imagenet/transforms.py (72 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates.
"""Tensor-level color augmentations for ImageNet-style training.

Every transform takes a single image tensor laid out as ``(3, H, W)``
(channels first; the grayscale weights assume RGB order) and returns the
augmented tensor. No transform modifies its input in place.
"""

from random import sample

import torch

# Default augmentation values compatible with ImageNet data augmentation pipeline.
# Tuples (not lists) so the shared default arguments below cannot be mutated.
_DEFAULT_ALPHASTD = 0.1
_DEFAULT_EIGVAL = (0.2175, 0.0188, 0.0045)
_DEFAULT_EIGVEC = (
    (-0.5675, 0.7192, 0.4009),
    (-0.5808, -0.0045, -0.8140),
    (-0.5836, -0.6948, 0.4203),
)
_DEFAULT_BCS = (0.4, 0.4, 0.4)


def _grayscale(img):
    """Return the ``(1, H, W)`` luma of a ``(3, H, W)`` image.

    Uses the ITU-R BT.601 weights (0.299, 0.587, 0.114), which assumes the
    channels are in RGB order.
    """
    weights = img.new_tensor([0.299, 0.587, 0.114])
    return (weights.view(3, 1, 1) * img).sum(0, keepdim=True)


def _blend(img1, img2, alpha):
    """Return ``alpha * img1 + (1 - alpha) * img2`` (broadcasting allowed)."""
    return img1 * alpha + (1 - alpha) * img2


def _random_factor(img, var):
    """Draw a 1-element tensor uniformly from ``[1 - var, 1 + var]``.

    The tensor is created with ``img``'s dtype and device so it broadcasts
    against the image without any transfer.
    """
    return img.new_empty(1).uniform_(-var, var) + 1.0


class Lighting:
    """Add PCA-style lighting noise (as in AlexNet) to an image.

    For each call, ``alpha ~ N(0, alphastd)`` is drawn with one value per
    column of ``eigvec``. The per-channel offset
    ``rgb[i] = sum_j eigvec[i, j] * alpha[j] * eigval[j]`` is then added to
    every pixel.

    Args:
        alphastd: standard deviation of ``alpha``; ``0`` makes the call a no-op
            that returns the input tensor itself.
        eigval: three values (sequence or tensor); column ``j`` of ``eigvec``
            is scaled by ``eigval[j]``.
        eigvec: 3x3 matrix (nested sequence or tensor).
    """

    def __init__(
        self, alphastd=_DEFAULT_ALPHASTD, eigval=_DEFAULT_EIGVAL, eigvec=_DEFAULT_EIGVEC
    ):
        self._alphastd = alphastd
        self._eigval = eigval
        self._eigvec = eigvec

    def __call__(self, img):
        if self._alphastd == 0.0:
            return img
        alpha = torch.normal(img.new_zeros(3), self._alphastd)
        # as_tensor accepts lists, tuples or tensors and matches img's dtype/device.
        eigval = torch.as_tensor(self._eigval, dtype=img.dtype, device=img.device)
        eigvec = torch.as_tensor(self._eigvec, dtype=img.dtype, device=img.device)
        # alpha and eigval broadcast along rows, so this is eigvec @ (alpha * eigval).
        rgb = (eigvec * alpha * eigval).sum(dim=1)
        return img + rgb.view(3, 1, 1)


class Saturation:
    """Blend the image with its own grayscale by a random factor in [1-var, 1+var]."""

    def __init__(self, var):
        self._var = var

    def __call__(self, img):
        gs = _grayscale(img)
        alpha = _random_factor(img, self._var)
        return _blend(img, gs, alpha)


class Brightness:
    """Scale the image by a random factor in [1-var, 1+var] (blend with black)."""

    def __init__(self, var):
        self._var = var

    def __call__(self, img):
        gs = torch.zeros_like(img)
        alpha = _random_factor(img, self._var)
        return _blend(img, gs, alpha)


class Contrast:
    """Blend the image with its mean luma by a random factor in [1-var, 1+var]."""

    def __init__(self, var):
        self._var = var

    def __call__(self, img):
        # Keep the mean as a 0-dim tensor: filling a new tensor from it would
        # convert it to a Python scalar, which is a device sync on CUDA.
        mean = _grayscale(img).mean()
        alpha = _random_factor(img, self._var)
        return _blend(img, mean, alpha)


class ColorJitter:
    """Apply saturation, brightness and contrast jitter in a random order.

    Passing ``None`` for any argument disables that component. If all three
    are ``None``, the input tensor is returned unchanged.
    """

    def __init__(
        self,
        saturation=_DEFAULT_BCS[0],
        brightness=_DEFAULT_BCS[1],
        contrast=_DEFAULT_BCS[2],
    ):
        candidates = (
            (saturation, Saturation),
            (brightness, Brightness),
            (contrast, Contrast),
        )
        self._transforms = [cls(var) for var, cls in candidates if var is not None]

    def __call__(self, img):
        # Sampling k == len(population) items yields a random permutation;
        # with no transforms this returns [] and img passes through untouched.
        for transform in sample(self._transforms, len(self._transforms)):
            img = transform(img)
        return img