def build_segmentation_model()

in cvnets/models/segmentation/__init__.py [0:0]


def build_segmentation_model(opts):
    """Build the segmentation model selected by ``model.segmentation.name``.

    The function reads the following attributes from ``opts`` (all looked up
    with ``getattr`` using the dotted name as the attribute key):

    * ``model.segmentation.name``: key into ``SEG_MODEL_REGISTRY``.
    * ``model.segmentation.output_stride``: forwarded to
      ``build_classification_model`` when building the encoder.
    * ``model.segmentation.activation.{name,inplace,neg_slope}``: when the
      name is set, these temporarily replace ``model.activation.*`` while the
      segmentation model is constructed.
    * ``model.segmentation.pretrained``: optional weights location, resolved
      with ``get_local_path`` and loaded with ``load_pretrained_model``.
    * ``model.segmentation.freeze_batch_norm``: when truthy,
      ``model.freeze_norm_layers()`` is called.

    Args:
        opts: Options object holding the attributes listed above.

    Returns:
        The constructed model (after optional weight loading and
        normalization-layer freezing).
    """
    seg_model_name = getattr(opts, "model.segmentation.name", None)
    model = None
    is_master_node = is_master(opts)

    if seg_model_name in SEG_MODEL_REGISTRY:
        output_stride = getattr(opts, "model.segmentation.output_stride", None)
        encoder = build_classification_model(opts=opts, output_stride=output_stride)
        model_builder = SEG_MODEL_REGISTRY[seg_model_name]

        seg_act_fn = getattr(opts, "model.segmentation.activation.name", None)
        if seg_act_fn is None:
            model = model_builder(opts, encoder)
        else:
            # Snapshot the general activation arguments so they can be put
            # back once the segmentation model has been constructed.
            general_act_args = {
                "model.activation.name": getattr(opts, "model.activation.name", "relu"),
                "model.activation.inplace": getattr(opts, "model.activation.inplace", False),
                "model.activation.neg_slope": getattr(opts, "model.activation.neg_slope", 0.1),
            }
            try:
                # Override the general activation arguments
                setattr(opts, "model.activation.name", seg_act_fn)
                setattr(
                    opts,
                    "model.activation.inplace",
                    getattr(opts, "model.segmentation.activation.inplace", False),
                )
                setattr(
                    opts,
                    "model.activation.neg_slope",
                    getattr(opts, "model.segmentation.activation.neg_slope", 0.1),
                )

                model = model_builder(opts, encoder)
            finally:
                # Reset activation args, even if construction raised, so that
                # ``opts`` is never left carrying the segmentation overrides.
                for arg_name, arg_value in general_act_args.items():
                    setattr(opts, arg_name, arg_value)
    else:
        supported_models = list(SEG_MODEL_REGISTRY) or ["none"]
        supp_model_str = "Supported segmentation models are:" + "".join(
            "\n\t {}: {}".format(i, logger.color_text(m_name))
            for i, m_name in enumerate(supported_models)
        )
        # NOTE(review): ``model`` is still None past this point; the code
        # below presumably relies on logger.error being fatal -- confirm.
        logger.error(supp_model_str)

    pretrained = getattr(opts, "model.segmentation.pretrained", None)
    if pretrained is not None:
        pretrained = get_local_path(opts, path=pretrained)
        model = load_pretrained_model(model=model, wt_loc=pretrained, is_master_node=is_master_node)

    freeze_norm_layers = getattr(opts, "model.segmentation.freeze_batch_norm", False)
    if freeze_norm_layers:
        model.freeze_norm_layers()
        frozen_state, count_norm = check_frozen_norm_layer(model)
        # NOTE(review): a truthy ``frozen_state`` is treated as a failure
        # here; presumably the helper reports layers that are still
        # trainable -- confirm against check_frozen_norm_layer.
        if count_norm > 0 and frozen_state and is_master_node:
            logger.error('Something is wrong while freezing normalization layers. Please check')

        if is_master_node:
            logger.log("Normalization layers are frozen")

    return model