def _batchnorm_to_instancenorm()

in opacus/validators/batch_norm.py [0:0]


def _batchnorm_to_instancenorm(module: BATCHNORM) -> INSTANCENORM:
    """
    Converts a BatchNorm module to the corresponding InstanceNorm module

    Only ``num_features`` is carried over from ``module``; every other
    constructor argument of the new InstanceNorm module is left at its default.

    Args:
        module: BatchNorm module to be replaced

    Returns:
        InstanceNorm module that can replace the BatchNorm module provided

    Raises:
        UnsupportableModuleError: If ``module`` is an ``nn.SyncBatchNorm`` (there
            is no InstanceNorm counterpart), or if it is not an instance of
            ``nn.BatchNorm1d``, ``nn.BatchNorm2d`` or ``nn.BatchNorm3d``.
    """

    def match_dim():
        # Pick the InstanceNorm class whose dimensionality matches the input.
        if isinstance(module, nn.BatchNorm1d):
            return nn.InstanceNorm1d
        elif isinstance(module, nn.BatchNorm2d):
            return nn.InstanceNorm2d
        elif isinstance(module, nn.BatchNorm3d):
            return nn.InstanceNorm3d
        elif isinstance(module, nn.SyncBatchNorm):
            raise UnsupportableModuleError(
                "There is no equivalent InstanceNorm module to replace"
                " SyncBatchNorm with. Consider replacing it with GroupNorm instead."
            )
        # Without this explicit error an unrecognised module type would make
        # match_dim() return None, and the call below would fail with an
        # unhelpful "'NoneType' object is not callable" TypeError.
        raise UnsupportableModuleError(
            f"Cannot convert module of type {type(module).__name__} to an"
            " InstanceNorm module: expected nn.BatchNorm1d, nn.BatchNorm2d"
            " or nn.BatchNorm3d."
        )

    return match_dim()(module.num_features)