in mobile_cv/arch/fbnet_v2/basic_blocks.py [0:0]
def build_bn(name, num_channels, zero_gamma=None, gamma_beta=None, **kwargs):
    """Build a normalization layer selected by ``name``.

    Args:
        name: Key selecting the layer type. ``None`` or ``"none"`` returns
            ``None``. Built-in keys: ``"bn"``, ``"sync_bn"``,
            ``"naiveSyncBN"`` (2d), ``"bn3d"``, ``"naiveSyncBN3d"`` (3d),
            ``"bn1d"``, ``"naiveSyncBN1d"`` (1d), ``"sync_bn_torch"``,
            ``"gn"``, ``"instance"`` and ``"frozen_bn"``. Any other key is
            looked up in ``BN_REGISTRY``.
        num_channels: Number of channels/features passed to the layer.
        zero_gamma: When exactly ``True``, the layer weight (gamma) is
            initialized to 0. This is applied to the batch-norm keys (bn*,
            sync/naiveSync variants). It is forwarded as a keyword to
            registry builders and ignored for ``"gn"``, ``"instance"`` and
            ``"frozen_bn"``.
        gamma_beta: Optional ``(gamma, beta)`` pair (tuple or list) used to
            initialize weight and bias to constants. It is applied after
            ``zero_gamma``, so it takes precedence. It is only applied to
            the batch-norm keys and is NOT forwarded to registry builders.
        **kwargs: Forwarded to the layer constructor.

    Returns:
        The constructed module, or ``None`` when ``name`` is ``None`` or
        ``"none"``.

    Raises:
        TypeError: ``gamma_beta`` is not a tuple or list (batch-norm keys).
        ValueError: ``gamma_beta`` does not have exactly two items, or an
            init option was requested for a layer without learnable
            weight/bias (e.g. built with ``affine=False``).
    """
    # Cheap short-circuit before building any factories.
    if name is None or name == "none":
        return None

    def _require_params(bn_op, option, param_names):
        # A layer built without affine params stores them as None, and
        # nn.init.constant_(None, ...) fails with an opaque AttributeError.
        for param_name in param_names:
            if getattr(bn_op, param_name, None) is None:
                raise ValueError(
                    f"{option} requires a norm layer with a learnable "
                    f"{param_name!r}, but {type(bn_op).__name__} has none "
                    "(was it built with affine=False?)"
                )

    def _create_bn(bn_class):
        # Instantiate a batch-norm-like layer and apply the optional inits.
        bn_op = bn_class(num_channels, **kwargs)
        if zero_gamma is True:
            _require_params(bn_op, "zero_gamma", ("weight",))
            nn.init.constant_(bn_op.weight, 0.0)
        if gamma_beta is not None:
            # Explicit checks instead of `assert`, which is stripped under -O.
            if not isinstance(gamma_beta, (tuple, list)):
                raise TypeError(
                    "gamma_beta must be a (gamma, beta) tuple or list, "
                    f"got {type(gamma_beta).__name__}"
                )
            if len(gamma_beta) != 2:
                raise ValueError(
                    f"gamma_beta must have exactly 2 items, got {len(gamma_beta)}"
                )
            _require_params(bn_op, "gamma_beta", ("weight", "bias"))
            gamma, beta = gamma_beta
            nn.init.constant_(bn_op.weight, gamma)
            nn.init.constant_(bn_op.bias, beta)
        return bn_op

    # Lambdas defer constructing the layer (and resolving the class names)
    # until the matching key is actually requested.
    bn_builders = {
        # 2d
        "bn": lambda: _create_bn(nn.BatchNorm2d),
        "sync_bn": lambda: _create_bn(NaiveSyncBatchNorm),
        "naiveSyncBN": lambda: _create_bn(NaiveSyncBatchNorm),
        # 3d
        "bn3d": lambda: _create_bn(nn.BatchNorm3d),
        "naiveSyncBN3d": lambda: _create_bn(NaiveSyncBatchNorm3d),
        # 1d
        "bn1d": lambda: _create_bn(nn.BatchNorm1d),
        "naiveSyncBN1d": lambda: _create_bn(NaiveSyncBatchNorm1d),
        # any dimension
        "sync_bn_torch": lambda: _create_bn(nn.SyncBatchNorm),
        # others (zero_gamma / gamma_beta are not applied here)
        "gn": lambda: GroupNorm(num_channels=num_channels, **kwargs),
        "instance": lambda: nn.InstanceNorm2d(num_channels, **kwargs),
        "frozen_bn": lambda: FrozenBatchNorm2d(num_channels, **kwargs),
    }
    builder = bn_builders.get(name)
    if builder is not None:
        return builder()
    # Fall back to user-registered norm layers; gamma_beta is not forwarded.
    return BN_REGISTRY.get(name)(num_channels, zero_gamma=zero_gamma, **kwargs)