in cvnets/models/segmentation/__init__.py [0:0]
def build_segmentation_model(opts):
    """Build the segmentation model selected by ``model.segmentation.name``.

    The steps, in order:

    1. Look up ``model.segmentation.name`` in ``SEG_MODEL_REGISTRY``. If it is
       not registered, report the supported model names via ``logger.error``.
    2. Build the encoder with ``build_classification_model``, forwarding
       ``model.segmentation.output_stride``.
    3. Construct the segmentation model. When
       ``model.segmentation.activation.name`` is set, the general
       ``model.activation.*`` options are temporarily replaced by their
       ``model.segmentation.activation.*`` counterparts for the duration of
       the construction and restored afterwards, even if construction fails.
    4. If ``model.segmentation.pretrained`` is set, resolve it with
       ``get_local_path`` and load it with ``load_pretrained_model``.
    5. If ``model.segmentation.freeze_batch_norm`` is truthy, call
       ``model.freeze_norm_layers()`` and verify the result with
       ``check_frozen_norm_layer``.

    Args:
        opts: Options object whose attributes are read (and temporarily
            written) with ``getattr``/``setattr`` using dotted names.

    Returns:
        The constructed segmentation model.
    """
    seg_model_name = getattr(opts, "model.segmentation.name", None)
    is_master_node = is_master(opts)
    model = None

    if seg_model_name in SEG_MODEL_REGISTRY:
        output_stride = getattr(opts, "model.segmentation.output_stride", None)
        encoder = build_classification_model(opts=opts, output_stride=output_stride)
        seg_model_fn = SEG_MODEL_REGISTRY[seg_model_name]

        seg_act_fn = getattr(opts, "model.segmentation.activation.name", None)
        if seg_act_fn is None:
            model = seg_model_fn(opts, encoder)
        else:
            # Remember the general activation arguments so they can be put
            # back once the segmentation model has been constructed.
            gen_act_fn = getattr(opts, "model.activation.name", "relu")
            gen_act_inplace = getattr(opts, "model.activation.inplace", False)
            gen_act_neg_slope = getattr(opts, "model.activation.neg_slope", 0.1)

            seg_act_inplace = getattr(
                opts, "model.segmentation.activation.inplace", False
            )
            seg_act_neg_slope = getattr(
                opts, "model.segmentation.activation.neg_slope", 0.1
            )

            try:
                # Override the general activation arguments
                setattr(opts, "model.activation.name", seg_act_fn)
                setattr(opts, "model.activation.inplace", seg_act_inplace)
                setattr(opts, "model.activation.neg_slope", seg_act_neg_slope)

                model = seg_model_fn(opts, encoder)
            finally:
                # Reset activation args. Done in ``finally`` so a failing
                # constructor cannot leave the shared ``opts`` mutated.
                setattr(opts, "model.activation.name", gen_act_fn)
                setattr(opts, "model.activation.inplace", gen_act_inplace)
                setattr(opts, "model.activation.neg_slope", gen_act_neg_slope)
    else:
        supported_models = list(SEG_MODEL_REGISTRY.keys()) or ["none"]
        supp_model_str = "Supported segmentation models are:" + "".join(
            "\n\t {}: {}".format(i, logger.color_text(m_name))
            for i, m_name in enumerate(supported_models)
        )
        # NOTE(review): presumably ``logger.error`` aborts execution; if it
        # returns, ``model`` is still None for the steps below. Confirm.
        logger.error(supp_model_str)

    pretrained = getattr(opts, "model.segmentation.pretrained", None)
    if pretrained is not None:
        pretrained = get_local_path(opts, path=pretrained)
        model = load_pretrained_model(
            model=model, wt_loc=pretrained, is_master_node=is_master_node
        )

    freeze_norm_layers = getattr(opts, "model.segmentation.freeze_batch_norm", False)
    if freeze_norm_layers:
        model.freeze_norm_layers()
        frozen_state, count_norm = check_frozen_norm_layer(model)
        # NOTE(review): the meaning of ``frozen_state`` is defined by
        # ``check_frozen_norm_layer`` (not visible here); the condition is
        # kept exactly as originally written.
        if count_norm > 0 and frozen_state and is_master_node:
            logger.error('Something is wrong while freezing normalization layers. Please check')

        if is_master_node:
            logger.log("Normalization layers are frozen")

    return model