def get_normalization_layer()

in cvnets/layers/normalization_layers.py [0:0]


def get_normalization_layer(opts, num_features: int, norm_type: Optional[str] = None, num_groups: Optional[int] = None,
                            **kwargs):
    """Instantiate a normalization layer chosen by name.

    Args:
        opts: Options object. The following attributes are read with ``getattr``:
            ``model.normalization.name`` (fallback ``"batch_norm"``, used only
            when ``norm_type`` is None), ``model.normalization.groups``
            (fallback ``1``, used only when ``num_groups`` is None) and
            ``model.normalization.momentum`` (fallback ``0.1``).
        num_features: Channel/feature count passed to the chosen layer.
        norm_type: Layer name, matched case-insensitively. Accepted names are
            ``batch_norm``/``batch_norm_2d``, ``batch_norm_1d``,
            ``sync_batch_norm``/``sbn``, ``group_norm``/``gn``,
            ``instance_norm``/``instance_norm_2d``, ``instance_norm_1d``,
            ``layer_norm``/``ln`` and ``identity``.
        num_groups: Requested group count for group norm.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        The constructed layer. For an unrecognized name, ``logger.error`` is
        called and ``None`` is returned if that call returns.
        NOTE(review): whether ``logger.error`` terminates the process is
        defined outside this function — confirm.
    """
    if norm_type is None:
        norm_type = getattr(opts, "model.normalization.name", "batch_norm")
    if num_groups is None:
        num_groups = getattr(opts, "model.normalization.groups", 1)
    momentum = getattr(opts, "model.normalization.momentum", 0.1)

    # Case-insensitive matching; a None name (e.g. explicitly set in opts) falls through to the error path.
    name = None if norm_type is None else norm_type.lower()

    if name in ('batch_norm', 'batch_norm_2d'):
        return BatchNorm2d(num_features=num_features, momentum=momentum)
    if name == 'batch_norm_1d':
        return BatchNorm1d(num_features=num_features, momentum=momentum)
    if name in ('sync_batch_norm', 'sbn'):
        return SyncBatchNorm(num_features=num_features, momentum=momentum)
    if name in ('group_norm', 'gn'):
        # Shrink the requested group count to a divisor of num_features so the
        # channels split evenly into groups.
        return GroupNorm(num_channels=num_features, num_groups=math.gcd(num_features, num_groups))
    if name in ('instance_norm', 'instance_norm_2d'):
        return InstanceNorm2d(num_features=num_features, momentum=momentum)
    if name == "instance_norm_1d":
        return InstanceNorm1d(num_features=num_features, momentum=momentum)
    if name in ('layer_norm', 'ln'):
        return LayerNorm(num_features)
    if name == 'identity':
        return Identity()

    # Unknown name: report the supported options together with the (lower-cased) value received.
    logger.error(
        'Supported normalization layer arguments are: {}. Got: {}'.format(SUPPORTED_NORM_FNS, name))
    return None