in models/diff_augment.py [0:0]
def DiffAugment(x, policy='', normalize=False, channels_first=True):
    """Differentiable Augmentation for Data-Efficient GAN Training.
    https://arxiv.org/pdf/2006.10738

    Applies, in order, every augmentation function registered in
    ``AUGMENT_FNS`` under each name listed in ``policy``.

    Args:
        x: Image batch tensor. When ``x.ndim == 5`` it is passed through
            ``collapse_trajectory_dim`` before augmenting and restored with
            ``expand_trajectory_dim(x, x.shape[1])`` afterwards (presumably a
            (batch, T, ...) trajectory layout -- TODO confirm against those
            helpers). Other inputs are augmented directly and must be 4-D.
        policy: Comma-separated augmentation names, e.g. ``'color,translation'``.
            Whitespace around names and empty entries (such as a trailing
            comma) are ignored. A falsy policy returns ``x`` itself, untouched.
        normalize: If True, values are mapped from [0, 1] to [-1, 1] before
            the augmentations and mapped back to [0, 1] afterwards.
        channels_first: If False, the (collapsed) 4-D tensor is permuted from
            channels-last to channels-first for the augmentations and then
            permuted back, so the caller's axis order is preserved.

    Returns:
        The augmented tensor, made contiguous.

    Raises:
        KeyError: If a policy name is not a key of ``AUGMENT_FNS``.
    """
    if not policy:
        return x

    # Tolerate 'color, translation' and 'color,' style policies: strip each
    # token and skip empty ones instead of looking up ' translation' or ''.
    names = [name.strip() for name in policy.split(',') if name.strip()]

    # 5-D inputs have an extra axis at dim 1; fold it away so the augment
    # functions only see 4-D tensors, and remember its length to undo it.
    T = x.shape[1] if x.ndim == 5 else None
    if T is not None:
        x = collapse_trajectory_dim(x)

    if normalize:
        x = x * 2 - 1.0  # shift from [0, 1] to [-1, 1]
    if not channels_first:
        x = x.permute(0, 3, 1, 2)  # channels-last -> channels-first

    for name in names:
        for f in AUGMENT_FNS[name]:
            x = f(x)

    if not channels_first:
        x = x.permute(0, 2, 3, 1)  # channels-first -> channels-last
    # permute() yields a strided view; materialize a contiguous tensor
    # (a no-op when it already is contiguous).
    x = x.contiguous()
    if normalize:
        x = (x + 1) / 2  # shift back from [-1, 1] to [0, 1]
    if T is not None:
        x = expand_trajectory_dim(x, T)
    return x