def convert_state_dict_if_from_pl()

in pytouch/utils/model_utils.py [0:0]


def convert_state_dict_if_from_pl(checkpoint):
    """Return a plain model state dict, unwrapping a PyTorch-Lightning checkpoint if needed.

    If ``checkpoint`` has no ``"state_dict"`` key, it is returned unchanged.
    Otherwise, the entries of ``checkpoint["state_dict"]`` whose keys start with
    ``"model."`` are copied into a new ``OrderedDict`` with that single leading
    prefix removed. All other entries (keys without the prefix, and every other
    top-level checkpoint field) are dropped.

    Args:
        checkpoint: A loaded checkpoint mapping — either a plain state dict or a
            Lightning checkpoint holding a ``"state_dict"`` entry (presumably the
            result of ``torch.load``; confirm against callers).

    Returns:
        ``checkpoint`` itself when it has no ``"state_dict"`` key, otherwise a new
        ``OrderedDict`` of the extracted entries, in their original order.
    """
    if "state_dict" not in checkpoint:
        _log.debug("Checkpoint is not a PyTorch-Lightning saved model.")
        return checkpoint

    _log.debug("Checkpoint is a PyTorch-Lightning saved model, extracting.")
    prefix = "model."
    pl_state_dict = OrderedDict()
    for k, v in checkpoint["state_dict"].items():
        # Only entries under the "model." prefix are kept (presumably the
        # LightningModule's wrapped ``self.model``). Strip just the leading
        # prefix: ``str.replace`` would also delete "model." occurrences deeper
        # in the key, e.g. "model.submodel.bias" -> "subbias".
        if k.startswith(prefix):
            pl_state_dict[k[len(prefix):]] = v
    return pl_state_dict