def load_checkpoint()

in src/sagemaker_defect_detection/utils/__init__.py [0:0]


def load_checkpoint(model: nn.Module, path: str, prefix: Optional[str]) -> nn.Module:
    """Load the ``"state_dict"`` stored in a checkpoint file into ``model``.

    Parameters
    ----------
    model : nn.Module
        Module that receives the weights via ``load_state_dict(..., strict=True)``.
    path : str
        Either a checkpoint file, or a directory that is searched recursively
        for ``*.ckpt`` files. If several are found, the lexicographically first
        path is used so the choice is reproducible across runs.
    prefix : Optional[str]
        If non-empty, only keys starting with ``prefix`` are kept, and the
        prefix is stripped from them. A trailing ``"."`` is appended if missing.
        ``None`` or ``""`` keeps every key unchanged.

    Returns
    -------
    nn.Module
        The same ``model`` object, after its weights have been loaded.

    Raises
    ------
    FileNotFoundError
        If ``path`` is a directory that contains no ``*.ckpt`` file.
    KeyError
        If the loaded checkpoint has no ``"state_dict"`` entry.
    RuntimeError
        Raised by ``load_state_dict`` when the (filtered) keys do not match the
        model's keys exactly (``strict=True``).
    """
    path = Path(path)
    logger.info("path: %s", path)
    if path.is_dir():
        # rglob order depends on the filesystem; sort so the same checkpoint is
        # selected on every run when the directory holds more than one.
        candidates = sorted(path.rglob("*.ckpt"))
        if not candidates:
            raise FileNotFoundError(f"no *.ckpt file found under directory {path}")
        path_str = str(candidates[0])
    else:
        path_str = str(path)

    # Map tensors onto the GPU when one is available, otherwise onto the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    # Assumes the checkpoint is a dict with a "state_dict" entry (e.g. a PyTorch
    # Lightning checkpoint) — TODO confirm against the producer of these files.
    state_dict = torch.load(path_str, map_location=torch.device(device))["state_dict"]

    if prefix:
        if not prefix.endswith("."):
            prefix += "."
        # Keep only the entries under ``prefix`` and strip it, so the keys line
        # up with ``model``'s own parameter/buffer names.
        offset = len(prefix)
        state_dict = {k[offset:]: v for k, v in state_dict.items() if k.startswith(prefix)}

    model.load_state_dict(state_dict, strict=True)
    return model