def DiffAugment()

in models/diff_augment.py [0:0]


def DiffAugment(x, policy='', normalize=False, channels_first=True):
    """Differentiable Augmentation for Data-Efficient GAN Training.

    https://arxiv.org/pdf/2006.10738

    For each comma-separated name in ``policy`` (in order), applies every
    function registered under that name in ``AUGMENT_FNS``.

    Args:
        x: Batch of images. 5-D input is flattened with
            ``collapse_trajectory_dim`` before augmenting and restored with
            ``expand_trajectory_dim`` afterwards, using ``x.shape[1]`` as the
            trajectory length (presumably a (batch, time, ...) layout —
            confirm against those helpers). Other input is augmented as-is.
        policy: Comma-separated augmentation names, e.g.
            ``"color,translation"``. Whitespace around names and empty
            entries (e.g. a trailing comma) are ignored.
        normalize: If True, ``x`` is mapped from [0, 1] to [-1, 1] before
            augmenting and mapped back to [0, 1] afterwards.
        channels_first: If False, ``x`` is treated as channels-last. It is
            permuted to channels-first for the augmentation functions and
            permuted back afterwards.

    Returns:
        The augmented tensor in the same layout as the input, made
        contiguous. If ``policy`` names no augmentations (empty, ``None``,
        or only separators/whitespace), ``x`` itself is returned untouched.

    Raises:
        KeyError: If ``policy`` names an augmentation missing from
            ``AUGMENT_FNS``.
    """
    # Tolerate "color, translation" and trailing commas instead of failing
    # on a lookup of ' translation' or '' in AUGMENT_FNS.
    names = [name.strip() for name in policy.split(',')] if policy else []
    names = [name for name in names if name]
    if not names:
        return x

    # Validate up front so a typo fails before any tensor work is done.
    unknown = [name for name in names if name not in AUGMENT_FNS]
    if unknown:
        raise KeyError(
            f"Unknown DiffAugment policy entries {unknown}; "
            f"available: {sorted(AUGMENT_FNS)}"
        )

    is_trajectory = x.ndim == 5
    if is_trajectory:
        traj_len = x.shape[1]  # needed to undo the collapse below
        x = collapse_trajectory_dim(x)

    if normalize:
        x = x * 2 - 1.0  # shift from [0, 1] to [-1, 1]

    if not channels_first:
        x = x.permute(0, 3, 1, 2)  # channels-last -> channels-first

    for name in names:
        for fn in AUGMENT_FNS[name]:
            x = fn(x)

    if not channels_first:
        x = x.permute(0, 2, 3, 1)  # channels-first -> channels-last
    # permute returns a strided view; hand callers a contiguous tensor.
    x = x.contiguous()

    if normalize:
        x = (x + 1) / 2  # shift back from [-1, 1] to [0, 1]

    if is_trajectory:
        x = expand_trajectory_dim(x, traj_len)
    return x