in quant/utils/checkpoints.py [0:0]
def get_path_to_checkpoint(experiment_path: Path, epoch: Optional[int] = None) -> str:
    """
    Find checkpoint file path in an experiment directory.

    Checkpoints are looked up in ``experiment_path / 'checkpoints'`` and are
    assumed to follow the `checkpoint_{epoch}.pt` naming format. Entries
    whose name does not yield an integer epoch (e.g. `.DS_Store`, log files)
    are ignored instead of aborting the lookup.

    Args:
        experiment_path: path to an experiment directory
        epoch: If given, return the checkpoint for that epoch; otherwise
            return the checkpoint with the highest epoch number.

    Returns:
        Path to checkpoint file, as a string.

    Raises:
        ValueError: If the checkpoints directory contains no checkpoint, or
            if `epoch` is given and no checkpoint exists for it.

    Errors raised while listing the checkpoints directory (for example when
    it does not exist) are not caught here and propagate to the caller.
    """
    ckpts_path = experiment_path / 'checkpoints'

    epoch_to_path = {}
    for path in ckpts_path.iterdir():
        try:
            # `checkpoint_{epoch}.pt` -> text between the first '_' and the
            # following '.' is the epoch number.
            ckpt_epoch = int(path.name.split('_')[1].split('.')[0])
        except (IndexError, ValueError):
            # Not a checkpoint file name (no '_' or a non-integer epoch
            # field); skip it rather than failing the whole lookup.
            continue
        epoch_to_path[ckpt_epoch] = path

    if not epoch_to_path:
        raise ValueError(
            f'No checkpoint exists in the experiment directory: {experiment_path}'
        )

    if epoch is None:
        # Default to the most recent (highest-numbered) checkpoint.
        epoch = max(epoch_to_path)
    elif epoch not in epoch_to_path:
        raise ValueError(f'Could not find checkpoint for epoch {epoch}.')

    return str(epoch_to_path[epoch])