def normalize_module2D()

in train_mnist.py [0:0]


def normalize_module2D(module, norm, dim):
    """
    Pair `module` with the normalization selected by `norm`.

    `norm` is one of:
      'none'     -> `[module]`
      'batch'    -> `[module, nn.BatchNorm2d(dim)]`
      'instance' -> `[module, nn.InstanceNorm2d(dim)]`
      'layer'    -> `[module, nn.GroupNorm(1, dim)]` (GroupNorm with one group)
      'spectral' -> `[spectral_norm(module)]`

    `dim` is only used by the 'batch', 'instance' and 'layer' options.
    Always returns a list of modules; raises NotImplementedError for any
    other value of `norm`.
    """
    # These two options do not append a separate normalization layer.
    if norm == 'none':
        return [module]
    if norm == 'spectral':
        return [spectral_norm(module)]

    # Every remaining option places a normalization layer after `module`.
    if norm == 'batch':
        norm_layer = nn.BatchNorm2d(dim)
    elif norm == 'instance':
        norm_layer = nn.InstanceNorm2d(dim)
    elif norm == 'layer':
        norm_layer = nn.GroupNorm(1, dim)
    else:
        raise NotImplementedError('normalization [%s] is not found' % norm)

    return [module, norm_layer]