in opacus/validators/batch_norm.py [0:0]
def _batchnorm_to_instancenorm(module: BATCHNORM) -> INSTANCENORM:
    """
    Converts a BatchNorm module to the corresponding InstanceNorm module

    The dimensionality is preserved (BatchNorm1d -> InstanceNorm1d, and so on).
    Only ``num_features`` is copied from ``module``. Every other setting of the
    new module (e.g. eps, momentum, affine) comes from the InstanceNorm
    constructor defaults, not from ``module``.

    Args:
        module: BatchNorm module to be replaced

    Returns:
        InstanceNorm module that can replace the BatchNorm module provided

    Raises:
        UnsupportableModuleError: If ``module`` is an ``nn.SyncBatchNorm``,
            for which there is no InstanceNorm counterpart.
        TypeError: If ``module`` is not an ``nn.BatchNorm1d``,
            ``nn.BatchNorm2d``, ``nn.BatchNorm3d`` or ``nn.SyncBatchNorm``.
    """

    def match_dim():
        # Pick the InstanceNorm class whose dimensionality matches `module`.
        if isinstance(module, nn.BatchNorm1d):
            return nn.InstanceNorm1d
        elif isinstance(module, nn.BatchNorm2d):
            return nn.InstanceNorm2d
        elif isinstance(module, nn.BatchNorm3d):
            return nn.InstanceNorm3d
        elif isinstance(module, nn.SyncBatchNorm):
            raise UnsupportableModuleError(
                "There is no equivalent InstanceNorm module to replace"
                " SyncBatchNorm with. Consider replacing it with GroupNorm instead."
            )
        # Without this explicit error, an unrecognised module would make this
        # helper return None, and the call below would fail with the opaque
        # "'NoneType' object is not callable".
        raise TypeError(
            f"Cannot convert {type(module).__name__} to an InstanceNorm module:"
            " expected nn.BatchNorm1d, nn.BatchNorm2d or nn.BatchNorm3d."
        )

    return match_dim()(module.num_features)