in models/framework.py [0:0]
def load_model_from_checkpoint(model: Vocoder, checkpoint_path: str) -> None:
    """
    Restore a model, its optimizers and its LR schedulers from a checkpoint.

    The checkpoint is read as a mapping with the keys ``"model"``,
    ``"optimizers"`` and ``"lr_schedulers"``. The model weights are loaded
    non-strictly (``strict=False``), so keys that do not match between the
    checkpoint and the model do not raise.

    Args:
        model: The model to restore. It is modified in place, together with
            the optimizers and schedulers returned by
            ``model.get_optimizers()``.
        checkpoint_path: The path to the checkpoint.

    Raises:
        AssertionError: If any tensor in ``checkpoint["model"]`` contains NaN.
    """
    print(f"Loading model from checkpoint: {checkpoint_path}")

    # Deserialize onto the CPU first so that a checkpoint written on a GPU
    # machine can still be read on a CPU-only host; without ``map_location``
    # torch tries to restore tensors onto the device they were saved from.
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if torch.cuda.is_available():
        # NOTE(review): presumably moves the checkpoint tensors to the GPU;
        # the helper is defined elsewhere -- confirm.
        checkpoint = move_state_dict_to_device(checkpoint, cpu=False)

    model_state = checkpoint["model"]

    # Check that we don't have NaN in the checkpoint.
    # This should never happen; this is a sanity check. It is raised
    # explicitly (rather than via ``assert``) so it still runs under
    # ``python -O``, while keeping the exception type callers may rely on.
    for key, tensor in model_state.items():
        if isinstance(tensor, torch.Tensor) and torch.isnan(tensor).any():
            raise AssertionError(f"Found NaN in checkpoint tensor {key}")

    # All checks have passed. Load the state_dict.
    model.load_state_dict(model_state, strict=False)

    optimizers = list(model.get_optimizers())
    opt_states = list(checkpoint["optimizers"])
    sched_states = list(checkpoint["lr_schedulers"])
    if not len(optimizers) == len(opt_states) == len(sched_states):
        # zip() below stops at the shortest sequence; make the otherwise
        # silent partial restore visible instead of hiding it.
        print(
            "Warning: checkpoint optimizer/scheduler count mismatch "
            f"(model has {len(optimizers)} optimizers, checkpoint has "
            f"{len(opt_states)} optimizer states and "
            f"{len(sched_states)} scheduler states); "
            "only the matching prefix is restored."
        )

    for (opt, sched), opt_dict, sched_dict in zip(
        optimizers, opt_states, sched_states
    ):
        opt.load_state_dict(opt_dict)
        # An optimizer may come without a scheduler; skip its saved entry.
        if sched is not None:
            sched.load_state_dict(sched_dict)