in pytouch/utils/model_utils.py [0:0]
def convert_state_dict_if_from_pl(checkpoint):
    """Return a plain state dict, unwrapping a PyTorch-Lightning checkpoint if needed.

    A checkpoint counts as a PyTorch-Lightning save when it has a
    ``"state_dict"`` key. In that case the nested state dict is extracted and
    a single leading ``"model."`` prefix is stripped from each key. Presumably
    the prefix exists because the LightningModule stores the network as
    ``self.model``; TODO confirm against the training code.

    Args:
        checkpoint: Mapping loaded from a checkpoint file. It is either a bare
            state dict or a Lightning checkpoint containing ``"state_dict"``.

    Returns:
        ``checkpoint`` itself (not a copy) when it has no ``"state_dict"``
        key. Otherwise, a new ``OrderedDict`` whose keys have the leading
        prefix removed and whose values are the original objects.
    """
    if "state_dict" not in checkpoint:
        _log.debug("Checkpoint is not a PyTorch-Lightning saved model.")
        return checkpoint

    _log.debug("Checkpoint is a PyTorch-Lightning saved model, extracting.")
    prefix = "model."
    pl_state_dict = OrderedDict()
    for key, value in checkpoint["state_dict"].items():
        # Strip only the *leading* prefix. ``str.replace`` would also remove
        # "model." occurring later in the key (e.g. "model.encoder.model.w"
        # -> "encoder.w"), producing keys that no longer match the network.
        if key.startswith(prefix):
            key = key[len(prefix):]
        pl_state_dict[key] = value
    return pl_state_dict