scripts/imagenet/transforms.py (72 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates.
from random import sample
import torch
# Default augmentation values compatible with the ImageNet data augmentation pipeline.

# Default `alphastd` of Lighting: standard deviation of the normal noise it draws per call.
_DEFAULT_ALPHASTD = 0.1
# Default `eigval` of Lighting: one scale factor per column of _DEFAULT_EIGVEC.
# NOTE(review): presumably PCA eigenvalues of ImageNet RGB pixel values - confirm the source.
_DEFAULT_EIGVAL = [0.2175, 0.0188, 0.0045]
# Default `eigvec` of Lighting: a 3x3 matrix. Lighting weights each column by the
# matching noise sample and eigval entry, then sums every row to obtain the offset
# added to the corresponding image channel.
_DEFAULT_EIGVEC = [
[-0.5675, 0.7192, 0.4009],
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, 0.4203],
]
# Default jitter strengths of ColorJitter, which reads index 0 as saturation,
# index 1 as brightness and index 2 as contrast.
# NOTE(review): the "BCS" suffix suggests brightness/contrast/saturation order, which
# does not match those indices; harmless while all three entries are equal - confirm.
_DEFAULT_BCS = [0.4, 0.4, 0.4]
def _grayscale(img):
    """Collapse the colour channels of an image tensor into one luminance channel.

    Each channel is scaled by its luma weight (0.299, 0.587, 0.114 - the weights
    assume channels are ordered R, G, B) and the weighted channels are summed.

    Args:
        img: Tensor whose third-from-last dimension holds the 3 colour channels,
            e.g. ``(3, H, W)`` or a batch ``(N, 3, H, W)``.

    Returns:
        Tensor on the same device as ``img`` with the channel dimension reduced
        to size 1 (kept, so the result broadcasts against ``img``). Floating
        point inputs keep their dtype.
    """
    # new_tensor is the documented way to build a tensor matching img's dtype and
    # device; it replaces the legacy ``img.new(data)`` constructor.
    weights = img.new_tensor([0.299, 0.587, 0.114]).view(3, 1, 1)
    # Reduce over the channel axis counted from the end. For inputs with up to
    # three dimensions this is dim 0, exactly as before; for batched input it
    # avoids summing across the batch dimension.
    return (weights * img).sum(dim=-3, keepdim=True)
def _blend(img1, img2, alpha):
    """Linearly mix two inputs.

    Args:
        img1: Input weighted by ``alpha``.
        img2: Input weighted by ``1 - alpha``.
        alpha: Mixing factor; 1 returns ``img1`` unchanged, 0 returns ``img2``.

    Returns:
        ``img1 * alpha + (1 - alpha) * img2``.
    """
    primary = img1 * alpha
    secondary = (1 - alpha) * img2
    return primary + secondary
class Lighting:
    """Callable transform adding a random per-channel colour offset to an image.

    On every call three noise values are drawn from a normal distribution with
    mean 0 and standard deviation ``alphastd``. The offset added to channel ``i``
    is ``sum_j eigvec[i][j] * noise[j] * eigval[j]``.

    Args:
        alphastd: Standard deviation of the noise; ``0.0`` disables the transform.
        eigval: Three scale factors, one per column of ``eigvec``.
        eigvec: 3x3 matrix with one row per image channel.
            Both ``eigval`` and ``eigvec`` may be anything ``torch.as_tensor``
            accepts (nested lists, tuples, arrays, tensors).
    """

    def __init__(
        self, alphastd=_DEFAULT_ALPHASTD, eigval=_DEFAULT_EIGVAL, eigvec=_DEFAULT_EIGVEC
    ):
        # The default lists are shared module-level constants; they are only
        # read here, never mutated.
        self._alphastd = alphastd
        self._eigval = eigval
        self._eigvec = eigvec

    def __call__(self, img):
        """Return ``img`` plus a freshly sampled colour offset.

        Args:
            img: Tensor with 3 channels in its third-from-last dimension,
                e.g. ``(3, H, W)``.

        Returns:
            ``img`` itself when ``alphastd`` is 0, otherwise a new tensor with
            the offset broadcast over the spatial dimensions.
        """
        if self._alphastd == 0.0:
            return img

        alpha = torch.normal(img.new_zeros(3), self._alphastd)
        # torch.as_tensor replaces the legacy ``img.new(data)`` constructor and
        # converts to the image's dtype/device whatever the container type is.
        eigval = torch.as_tensor(self._eigval, dtype=img.dtype, device=img.device)
        eigvec = torch.as_tensor(self._eigvec, dtype=img.dtype, device=img.device)

        # alpha and eigval broadcast along the columns of eigvec; summing each
        # row yields one offset per channel.
        rgb = (eigvec * alpha * eigval).sum(dim=1)
        return img + rgb.view(3, 1, 1)
class Saturation(object):
    """Callable transform that randomly changes colour saturation.

    The image is blended with its grayscale version; the image's weight is
    ``1 + u`` with ``u`` drawn uniformly between ``-var`` and ``var`` per call.
    """

    def __init__(self, var):
        # Maximum deviation of the blend factor from 1.0.
        self._var = var

    def __call__(self, img):
        """Return ``img`` blended with its grayscale version by a random factor."""
        gray = _grayscale(img)
        jitter = img.new(1).uniform_(-self._var, self._var)
        return _blend(img, gray, jitter + 1.0)
class Brightness(object):
    """Callable transform that randomly scales image brightness.

    Every call multiplies the whole image by a single factor ``1 + u`` with
    ``u`` drawn uniformly between ``-var`` and ``var``.

    Args:
        var: Maximum deviation of the scale factor from 1.0.
    """

    def __init__(self, var):
        self._var = var

    def __call__(self, img):
        """Return ``img`` multiplied by a freshly sampled brightness factor.

        Args:
            img: Floating point image tensor of any shape.

        Returns:
            New tensor with the same shape, dtype and device as ``img``.
        """
        # One-element tensor on img's dtype/device so the factor broadcasts.
        alpha = img.new_empty(1).uniform_(-self._var, self._var) + 1.0
        # Blending with an all-zero image reduces to a plain scaling, so the
        # zero tensor and the extra image-sized temporaries are not needed.
        return img * alpha
class Contrast(object):
    """Callable transform that randomly changes image contrast.

    The image is blended with a constant equal to the mean of its grayscale
    version; the image's weight is ``1 + u`` with ``u`` drawn uniformly between
    ``-var`` and ``var`` per call.
    """

    def __init__(self, var):
        # Maximum deviation of the blend factor from 1.0.
        self._var = var

    def __call__(self, img):
        """Return ``img`` blended with its mean gray level by a random factor."""
        mean_gray = _grayscale(img).mean()
        # A (1, 1, 1) tensor so the mean broadcasts over channels and pixels.
        flat = img.new_full((1, 1, 1), mean_gray)
        factor = img.new(1).uniform_(-self._var, self._var) + 1.0
        return _blend(img, flat, factor)
class ColorJitter(object):
    """Apply saturation, brightness and contrast jitter in a random order.

    Args:
        saturation: Strength passed to ``Saturation``; ``None`` skips it.
        brightness: Strength passed to ``Brightness``; ``None`` skips it.
        contrast: Strength passed to ``Contrast``; ``None`` skips it.
    """

    def __init__(
        self,
        saturation=_DEFAULT_BCS[0],
        brightness=_DEFAULT_BCS[1],
        contrast=_DEFAULT_BCS[2],
    ):
        # Pair every transform class with its strength, keep the enabled ones.
        candidates = (
            (Saturation, saturation),
            (Brightness, brightness),
            (Contrast, contrast),
        )
        self._transforms = [
            factory(strength)
            for factory, strength in candidates
            if strength is not None
        ]

    def __call__(self, img):
        """Return ``img`` after all enabled transforms, applied in random order."""
        if not self._transforms:
            return img
        # Sampling the full population yields a fresh random permutation.
        order = sample(self._transforms, len(self._transforms))
        for transform in order:
            img = transform(img)
        return img