in src/sagemaker_defect_detection/utils/__init__.py [0:0]
def load_checkpoint(model: nn.Module, path: str, prefix: Optional[str]) -> nn.Module:
    """Load the ``"state_dict"`` entry of a checkpoint into ``model``.

    Args:
        model: Module whose parameters are overwritten via ``load_state_dict``.
        path: Either a checkpoint file, or a directory that is searched
            recursively for ``*.ckpt`` files. When several match, the first
            one in sorted path order is used.
        prefix: Optional key prefix (e.g. a sub-module attribute name). When
            given, only keys starting with ``prefix + "."`` are kept, and that
            prefix is stripped before loading. A trailing ``"."`` is added if
            missing. ``None`` or ``""`` loads the state dict unchanged.

    Returns:
        The same ``model`` object, with weights loaded in place.

    Raises:
        FileNotFoundError: ``path`` is a directory containing no ``*.ckpt`` file.
        KeyError: The checkpoint has no ``"state_dict"`` entry.
        RuntimeError: Keys do not match the model exactly (``strict=True``).
    """
    path = Path(path)
    logger.info(f"path: {path}")
    if path.is_dir():
        # Sort so the chosen checkpoint does not depend on filesystem order.
        candidates = sorted(path.rglob("*.ckpt"))
        if not candidates:
            raise FileNotFoundError(f"no *.ckpt file found under {path}")
        path_str = str(candidates[0])
    else:
        path_str = str(path)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(path_str, map_location=torch.device(device))["state_dict"]
    if prefix:
        if not prefix.endswith("."):
            prefix += "."
        # Keep only the entries under `prefix` and strip it so the keys match `model`.
        state_dict = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
    model.load_state_dict(state_dict, strict=True)
    return model