in cvnets/models/classification/__init__.py [0:0]
def build_classification_model(opts, *args, **kwargs):
    """Build a registered classification model and apply optional post-processing.

    The model class is looked up in ``CLS_MODEL_REGISTRY`` using the option
    ``model.classification.name``.

    If ``model.classification.activation.name`` is set, the general
    ``model.activation.{name,inplace,neg_slope}`` options are temporarily
    replaced with the classification-specific values. This only lasts while
    the model is constructed. The previous values are restored afterwards,
    even if construction raises.

    After construction:
      * if ``model.classification.pretrained`` is set, weights are loaded
        from that location;
      * if ``model.classification.freeze_batch_norm`` is truthy, the
        model's normalization layers are frozen.

    Args:
        opts: Options object. Values are read with ``getattr``, and the
            activation options are temporarily modified with ``setattr``.
        *args: Forwarded unchanged to the registered model constructor.
        **kwargs: Forwarded unchanged to the registered model constructor.

    Returns:
        The constructed model, with pretrained weights loaded and/or
        normalization layers frozen when requested.

    Raises:
        ValueError: If ``model.classification.name`` is not a key of
            ``CLS_MODEL_REGISTRY``.
    """
    model_name = getattr(opts, "model.classification.name", None)
    is_master_node = is_master(opts)

    if model_name not in CLS_MODEL_REGISTRY:
        supported_models = list(CLS_MODEL_REGISTRY.keys())
        supp_model_str = "Supported models are:"
        for i, m_name in enumerate(supported_models):
            supp_model_str += "\n\t {}: {}".format(i, logger.color_text(m_name))
        if is_master_node:
            logger.error(
                "Unknown classification model: {}. {}".format(model_name, supp_model_str)
            )
        # Fail fast instead of continuing with ``model = None``. Otherwise the
        # None would be returned to the caller or would break the
        # pretrained/freezing steps below with an unrelated error.
        raise ValueError(
            "Unknown classification model: {}. Supported models: {}".format(
                model_name, supported_models
            )
        )

    cls_act_fn = getattr(opts, "model.classification.activation.name", None)
    if cls_act_fn is None:
        model = CLS_MODEL_REGISTRY[model_name](opts, *args, **kwargs)
    else:
        # Save the general activation arguments so they can be restored after
        # the classification-specific ones are used for this model only.
        gen_act_fn = getattr(opts, "model.activation.name", "relu")
        gen_act_inplace = getattr(opts, "model.activation.inplace", False)
        gen_act_neg_slope = getattr(opts, "model.activation.neg_slope", 0.1)

        setattr(opts, "model.activation.name", cls_act_fn)
        setattr(
            opts,
            "model.activation.inplace",
            getattr(opts, "model.classification.activation.inplace", False),
        )
        setattr(
            opts,
            "model.activation.neg_slope",
            getattr(opts, "model.classification.activation.neg_slope", 0.1),
        )
        try:
            model = CLS_MODEL_REGISTRY[model_name](opts, *args, **kwargs)
        finally:
            # Restore even if construction fails, so the override does not leak
            # into anything else that later reads ``model.activation.*``.
            setattr(opts, "model.activation.name", gen_act_fn)
            setattr(opts, "model.activation.inplace", gen_act_inplace)
            setattr(opts, "model.activation.neg_slope", gen_act_neg_slope)

    pretrained = getattr(opts, "model.classification.pretrained", None)
    if pretrained is not None:
        # Resolve the configured location to a local path before loading.
        pretrained = get_local_path(opts, path=pretrained)
        model = load_pretrained_model(
            model=model, wt_loc=pretrained, is_master_node=is_master_node
        )

    freeze_norm_layers = getattr(opts, "model.classification.freeze_batch_norm", False)
    if freeze_norm_layers:
        model.freeze_norm_layers()
        frozen_state, count_norm = check_frozen_norm_layer(model)
        # NOTE(review): a truthy ``frozen_state`` is treated as a failure here.
        # This presumably means check_frozen_norm_layer reports whether norm
        # parameters are still trainable. Confirm against that helper.
        if count_norm > 0 and frozen_state and is_master_node:
            logger.error(
                "Something is wrong while freezing normalization layers. Please check"
            )
        if is_master_node:
            logger.log("Normalization layers are frozen")

    return model