in vissl/trainer/train_task.py [0:0]
def _build_model(self, strict_load: bool = False):
"""
- Builds and returns the model used for the task. The returned model is not
copied to GPU yet (if using GPU) and not wrapped with DDP yet. Both are done
later by self.prepare().
- We also convert the model's BatchNorm layers to SyncBatchNorm if the user
has set the config option. Both PyTorch and Apex SyncBatchNorms are
supported.
- If the model is set to evaluation mode and the trunk (or the full model)
must be frozen, we freeze it.
- If the model must be initialized from a checkpoint or a user-provided
weights file, we initialize it from the checkpoint or the weights here.
"""
logging.info("Building model....")
# Instantiate the raw model as specified
model = build_model(self.config["MODEL"], self.config["OPTIMIZER"])
# Convert the BatchNorm layers to SyncBatchNorm if needed
# Both Apex and Pytorch SyncBatchNorms are GPU only
if (
self.config["MODEL"]["SYNC_BN_CONFIG"]["CONVERT_BN_TO_SYNC_BN"]
and self.config["MACHINE"]["DEVICE"] == "gpu"
):
model = convert_sync_bn(self.config, model)
# Enforce eval mode, no matter what the prior transforms have done.
# For instance, Apex converts batch-norms and sets `requires_grad` to True.
if self.config["MODEL"]["FEATURE_EVAL_SETTINGS"]["EVAL_MODE_ON"]:
if self.config["MODEL"]["FEATURE_EVAL_SETTINGS"]["FREEZE_TRUNK_ONLY"]:
logging.info(
"config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY=True, "
"will freeze trunk..."
)
model.freeze_trunk()
elif self.config["MODEL"]["FEATURE_EVAL_SETTINGS"]["FREEZE_TRUNK_AND_HEAD"]:
logging.info(
"config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_AND_HEAD=True, will "
"freeze trunk and head..."
)
model.freeze_head_and_trunk()
# If the user set PARAMS_FILE, assert that the file exists.
if (
self.checkpoint_path is None
and self.config["MODEL"]["WEIGHTS_INIT"]["PARAMS_FILE"]
):
assert g_pathmgr.exists(
self.config["MODEL"]["WEIGHTS_INIT"]["PARAMS_FILE"]
), "Specified PARAMS_FILE does NOT exist"
# If we want to initialize the model for fine-tuning or evaluation, we do it
# here, but only if no checkpoint exists yet. This matters when a training
# run dies and resumes: the checkpoint must take precedence over the initial
# weights.
if (
self.checkpoint_path is None
and self.config["MODEL"]["WEIGHTS_INIT"]["PARAMS_FILE"]
and g_pathmgr.exists(self.config["MODEL"]["WEIGHTS_INIT"]["PARAMS_FILE"])
):
model = self._restore_model_weights(model, strict=strict_load)
return model
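
Below is a minimal, hypothetical sketch (not part of train_task.py) of what the three transformations above roughly amount to in plain PyTorch. VISSL's convert_sync_bn, freeze_trunk/freeze_head_and_trunk, and _restore_model_weights do considerably more (Apex support, process groups, trunk/head bookkeeping); the helper names here are illustrative only.

import torch
import torch.nn as nn


def convert_bn_to_sync_bn_sketch(model: nn.Module) -> nn.Module:
    # PyTorch-only equivalent of the SyncBatchNorm conversion step: replace
    # every BatchNorm*d layer in the model with nn.SyncBatchNorm.
    return nn.SyncBatchNorm.convert_sync_batchnorm(model)


def freeze_module_sketch(module: nn.Module) -> None:
    # What "freezing" amounts to: disable gradients and keep the module in
    # eval mode so BN statistics are not updated during feature evaluation.
    for param in module.parameters():
        param.requires_grad = False
    module.eval()


def restore_weights_sketch(model: nn.Module, params_file: str, strict: bool = False) -> nn.Module:
    # Simplified view of initializing from a weights file: load the state
    # dict from disk and copy matching tensors into the model.
    state_dict = torch.load(params_file, map_location="cpu")
    model.load_state_dict(state_dict, strict=strict)
    return model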