in train_mnist.py [0:0]
def normalize_module2D(module, norm, dim):
    """Pair a 2D layer with the normalization selected by name.

    Args:
        module: the layer to normalize.
        norm: one of 'none', 'batch', 'instance', 'layer' or 'spectral'.
        dim: channel count passed to the norm layer. Used by 'batch',
            'instance' and 'layer'; ignored by 'none' and 'spectral'.

    Returns:
        A list of modules: ``[module]`` for 'none', ``[module, norm_layer]``
        for 'batch' / 'instance' / 'layer' (the latter is GroupNorm with a
        single group), or ``[spectral_norm(module)]`` for 'spectral'.

    Raises:
        NotImplementedError: if ``norm`` is not one of the names above.
    """
    # Each entry is a zero-argument factory, so only the selected option's
    # layers are ever constructed (and spectral_norm is only called when
    # chosen). Entries are checked in order with `==`, first match wins.
    builders = (
        ('none', lambda: [module]),
        ('batch', lambda: [module, nn.BatchNorm2d(dim)]),
        ('instance', lambda: [module, nn.InstanceNorm2d(dim)]),
        ('layer', lambda: [module, nn.GroupNorm(1, dim)]),
        ('spectral', lambda: [spectral_norm(module)]),
    )
    for name, build in builders:
        if norm == name:
            return build()
    raise NotImplementedError('normalization [%s] is not found' % norm)