in cvnets/layers/normalization_layers.py [0:0]
def get_normalization_layer(opts, num_features: int, norm_type: Optional[str] = None, num_groups: Optional[int] = None,
                            **kwargs):
    """Create a normalization layer chosen by (case-insensitive) name.

    Explicit arguments take precedence; otherwise values are read from ``opts``:
    ``model.normalization.name`` (fallback ``"batch_norm"``) and
    ``model.normalization.groups`` (fallback ``1``). The momentum is always read from
    ``model.normalization.momentum`` (fallback ``0.1``). It is forwarded only to the
    batch-norm and instance-norm variants.

    Recognised names:
        ``batch_norm`` / ``batch_norm_2d``, ``batch_norm_1d``, ``sync_batch_norm`` / ``sbn``,
        ``group_norm`` / ``gn``, ``instance_norm`` / ``instance_norm_2d``, ``instance_norm_1d``,
        ``layer_norm`` / ``ln``, ``identity``.

    Args:
        opts: Options object queried with ``getattr``.
        num_features: Channel/feature count handed to the layer constructor.
        norm_type: Layer name; overrides the ``opts`` value when not ``None``.
        num_groups: Requested group count for group norm; overrides ``opts`` when not ``None``.
        **kwargs: Accepted for call-site compatibility; unused.

    Returns:
        The constructed layer. For an unrecognised name, ``logger.error`` is called and
        ``None`` is returned (unless the logger itself stops execution).
    """
    if norm_type is None:
        norm_type = getattr(opts, "model.normalization.name", "batch_norm")
    if num_groups is None:
        num_groups = getattr(opts, "model.normalization.groups", 1)
    momentum = getattr(opts, "model.normalization.momentum", 0.1)

    # Keep None as-is so it falls through to the error report below.
    requested = None if norm_type is None else norm_type.lower()

    # (aliases, builder) pairs. Builders are lambdas so each layer class is only looked up
    # and instantiated when its alias actually matches.
    registry = (
        (('batch_norm', 'batch_norm_2d'),
         lambda: BatchNorm2d(num_features=num_features, momentum=momentum)),
        (('batch_norm_1d',),
         lambda: BatchNorm1d(num_features=num_features, momentum=momentum)),
        (('sync_batch_norm', 'sbn'),
         lambda: SyncBatchNorm(num_features=num_features, momentum=momentum)),
        # gcd() guarantees the group count divides num_features evenly.
        (('group_norm', 'gn'),
         lambda: GroupNorm(num_channels=num_features, num_groups=math.gcd(num_features, num_groups))),
        (('instance_norm', 'instance_norm_2d'),
         lambda: InstanceNorm2d(num_features=num_features, momentum=momentum)),
        (('instance_norm_1d',),
         lambda: InstanceNorm1d(num_features=num_features, momentum=momentum)),
        (('layer_norm', 'ln'),
         lambda: LayerNorm(num_features)),
        (('identity',),
         lambda: Identity()),
    )

    for aliases, build in registry:
        if requested in aliases:
            return build()

    logger.error(
        'Supported normalization layer arguments are: {}. Got: {}'.format(SUPPORTED_NORM_FNS, requested))
    return None