def build_bn()

in mobile_cv/arch/fbnet_v2/basic_blocks.py [0:0]


def build_bn(name, num_channels, zero_gamma=None, gamma_beta=None, **kwargs):
    """Build a normalization layer selected by ``name``.

    Args:
        name: Key selecting the layer type. ``None`` or ``"none"`` returns
            ``None``. Built-in keys are ``"bn"``, ``"sync_bn"``,
            ``"naiveSyncBN"`` (2d), ``"bn3d"``, ``"naiveSyncBN3d"`` (3d),
            ``"bn1d"``, ``"naiveSyncBN1d"`` (1d), ``"sync_bn_torch"``
            (``nn.SyncBatchNorm``), ``"gn"``, ``"instance"`` and
            ``"frozen_bn"``. Any other key is looked up in ``BN_REGISTRY``.
        num_channels: Number of channels passed to the layer constructor.
        zero_gamma: If exactly ``True``, initialize the scale (``weight``) of
            a BatchNorm-family layer to 0. Ignored by ``"gn"``,
            ``"instance"`` and ``"frozen_bn"``; forwarded as-is to
            ``BN_REGISTRY`` builders.
        gamma_beta: Optional ``(gamma, beta)`` pair (tuple or list) used to
            initialize ``weight`` and ``bias`` of a BatchNorm-family layer.
            Applied after ``zero_gamma``, so it takes precedence. Ignored by
            ``"gn"``, ``"instance"``, ``"frozen_bn"`` and registry builders.
        **kwargs: Extra constructor arguments forwarded to the layer.

    Returns:
        The constructed module, or ``None`` when ``name`` is ``None`` or
        ``"none"``.

    Raises:
        ValueError: If ``gamma_beta`` is not a 2-element tuple/list, or if
            ``zero_gamma``/``gamma_beta`` is requested for a BatchNorm-family
            layer that has no affine parameters (e.g. built with
            ``affine=False``).
    """
    if name is None or name == "none":
        return None

    def _create_bn(bn_class):
        # Instantiate a BatchNorm-family layer and apply the optional
        # gamma/beta initialization overrides.
        if gamma_beta is not None and (
            not isinstance(gamma_beta, (tuple, list)) or len(gamma_beta) != 2
        ):
            # Explicit check instead of `assert`, which is stripped under -O.
            # Lists are accepted because configs loaded from JSON/YAML
            # cannot express tuples.
            raise ValueError(
                f"gamma_beta must be a (gamma, beta) pair, got {gamma_beta!r}"
            )

        bn_op = bn_class(num_channels, **kwargs)

        if zero_gamma is True or gamma_beta is not None:
            # With affine=False the layer's weight/bias are None and
            # nn.init.constant_ would fail with an opaque AttributeError.
            if bn_op.weight is None or bn_op.bias is None:
                raise ValueError(
                    f"{bn_class.__name__} has no affine parameters "
                    "(affine=False?); cannot apply zero_gamma/gamma_beta"
                )
        if zero_gamma is True:
            nn.init.constant_(bn_op.weight, 0.0)
        if gamma_beta is not None:
            gamma, beta = gamma_beta
            nn.init.constant_(bn_op.weight, gamma)
            nn.init.constant_(bn_op.bias, beta)
        return bn_op

    # Factories are wrapped in lambdas so that only the selected layer is
    # constructed.
    BN_DEFAULT_MAPS = {
        # 2d
        "bn": lambda: _create_bn(nn.BatchNorm2d),
        "sync_bn": lambda: _create_bn(NaiveSyncBatchNorm),
        "naiveSyncBN": lambda: _create_bn(NaiveSyncBatchNorm),
        # 3d
        "bn3d": lambda: _create_bn(nn.BatchNorm3d),
        "naiveSyncBN3d": lambda: _create_bn(NaiveSyncBatchNorm3d),
        # 1d
        "bn1d": lambda: _create_bn(nn.BatchNorm1d),
        "naiveSyncBN1d": lambda: _create_bn(NaiveSyncBatchNorm1d),
        # any dimension
        "sync_bn_torch": lambda: _create_bn(nn.SyncBatchNorm),
        # others (zero_gamma / gamma_beta are not applied to these)
        "gn": lambda: GroupNorm(num_channels=num_channels, **kwargs),
        "instance": lambda: nn.InstanceNorm2d(num_channels, **kwargs),
        "frozen_bn": lambda: FrozenBatchNorm2d(num_channels, **kwargs),
    }

    factory = BN_DEFAULT_MAPS.get(name)
    if factory is not None:
        return factory()

    # Fallback to user-registered norms; only zero_gamma is forwarded.
    return BN_REGISTRY.get(name)(num_channels, zero_gamma=zero_gamma, **kwargs)