def load_model_from_checkpoint()

in models/framework.py [0:0]


def load_model_from_checkpoint(model: Vocoder, checkpoint_path: str) -> None:
    """
    Restore a model, its optimizers and its LR schedulers from a checkpoint.

    The checkpoint is first deserialized onto the CPU, so a checkpoint written
    by a GPU run can also be restored on a machine without CUDA. When CUDA is
    available, the loaded state is then passed through
    `move_state_dict_to_device(..., cpu=False)` as before.

    Args:
      model: The model to restore. `model.get_optimizers()` must yield
        `(optimizer, scheduler)` pairs, where `scheduler` may be None.
      checkpoint_path: The path to the checkpoint. The checkpoint must contain
        the keys "model", "optimizers" and "lr_schedulers".

    Raises:
      AssertionError: If any tensor in the checkpoint's "model" state
        contains NaN.
    """
    print(f"Loading model from checkpoint: {checkpoint_path}")
    # map_location="cpu": without it torch.load restores each tensor onto the
    # device it was saved from, which fails for CUDA-saved checkpoints on a
    # CPU-only host. Device placement is handled explicitly below.
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    if torch.cuda.is_available():
        checkpoint = move_state_dict_to_device(checkpoint, cpu=False)

    # Sanity check: a checkpoint should never contain NaN. Raised explicitly
    # (instead of `assert`) so the check is not stripped under `python -O`;
    # AssertionError is kept so callers see the same exception type as before.
    for key, tensor in checkpoint["model"].items():
        if isinstance(tensor, torch.Tensor) and torch.isnan(tensor).any():
            raise AssertionError(f"Found NaN in checkpoint tensor {key}")

    # All checks have passed. Load the state_dict; strict=False tolerates
    # missing/unexpected keys.
    model.load_state_dict(checkpoint["model"], strict=False)

    optimizers = list(model.get_optimizers())
    optimizer_states = checkpoint["optimizers"]
    scheduler_states = checkpoint["lr_schedulers"]

    # zip() silently stops at the shortest input; make a count mismatch
    # visible instead of quietly leaving some optimizers un-restored.
    if not len(optimizers) == len(optimizer_states) == len(scheduler_states):
        restored = min(len(optimizers), len(optimizer_states), len(scheduler_states))
        print(
            f"Warning: model has {len(optimizers)} optimizer(s) but the checkpoint "
            f"has {len(optimizer_states)} optimizer state(s) and "
            f"{len(scheduler_states)} scheduler state(s); only the first "
            f"{restored} will be restored."
        )

    for (opt, sched), opt_dict, sched_dict in zip(
        optimizers, optimizer_states, scheduler_states
    ):
        opt.load_state_dict(opt_dict)
        if sched is not None:
            sched.load_state_dict(sched_dict)