in d2go/utils/validation_monitor.py [0:0]
def fetch_checkpoints_till_final(checkpoint_dir):
    """
    A generator that yields all checkpoint paths under the given directory, it'll
    keep refreshing until model_final is found.

    The directory is polled in a loop. Periodic checkpoints that have not been
    yielded yet are yielded in ascending order of the iteration number parsed
    from their file name (``model_<iter>.pth``, or ``step=<iter>`` directly
    followed by the Lightning checkpoint extension). Once a path ending with
    ``model_final.pth`` or ``model_final.ckpt`` is present, it is yielded last
    and the generator finishes.

    While waiting (directory missing, or no final model yet) the generator
    sleeps with an exponential back-off that starts at 1 second, doubles after
    every sleep and is capped at 60 seconds; it is reset to 1 second whenever
    a new periodic checkpoint is yielded.

    Checkpoint files whose name carries no parsable iteration number are
    skipped (with a warning logged once per file) instead of aborting the
    generator.

    Args:
        checkpoint_dir: directory to watch; it does not need to exist yet.

    Yields:
        Path of each newly found checkpoint; the final model is always last.
    """
    MIN_SLEEP_INTERVAL = 1.0  # in seconds
    MAX_SLEEP_INTERVAL = 60.0  # in seconds
    FINAL_MODEL_SUFFIXES = ("model_final.pth", "model_final.ckpt")

    lightning_ext = ModelCheckpoint.FILE_EXTENSION
    # Compile once, outside the polling loop. The extension is escaped so that
    # its leading "." is matched literally rather than as a regex wildcard.
    lightning_iter_re = re.compile(
        r"(?<=step=)\d+(?={})".format(re.escape(lightning_ext))
    )
    d2_iter_re = re.compile(r"(?<=model_)\d+(?=\.pth)")

    sleep_interval = MIN_SLEEP_INTERVAL
    finished_checkpoints = set()
    # Paths whose iteration could not be parsed; remembered so that the warning
    # is emitted only once instead of on every poll.
    skipped_checkpoints = set()

    def _add_and_log(path):
        finished_checkpoints.add(path)
        logger.info("Found checkpoint: {}".format(path))
        return path

    def _log_and_sleep(sleep_interval):
        logger.info(
            "Sleep {} seconds while waiting for model_final.pth".format(sleep_interval)
        )
        time.sleep(sleep_interval)
        # Exponential back-off, capped at MAX_SLEEP_INTERVAL.
        return min(sleep_interval * 2, MAX_SLEEP_INTERVAL)

    def _get_lightning_checkpoints(path: str):
        # Every Lightning checkpoint in the directory except the "last" one.
        return [
            os.path.join(path, x)
            for x in PathManager.ls(path)
            if x.endswith(lightning_ext)
            and not x.startswith(ModelCheckpoint.CHECKPOINT_NAME_LAST)
        ]

    def _parse_iteration(path):
        # Return the iteration encoded in the file name, or None if there is none.
        pattern = lightning_iter_re if path.endswith(lightning_ext) else d2_iter_re
        match = pattern.search(path)
        return int(match.group(0)) if match is not None else None

    while True:
        if not PathManager.exists(checkpoint_dir):
            sleep_interval = _log_and_sleep(sleep_interval)
            continue

        checkpoint_paths = DetectionCheckpointer(
            None, save_dir=checkpoint_dir
        ).get_all_checkpoint_files()
        checkpoint_paths.extend(_get_lightning_checkpoints(checkpoint_dir))

        final_model_path = None
        periodic_checkpoints = []
        for path in sorted(checkpoint_paths):
            if path.endswith(FINAL_MODEL_SUFFIXES):
                final_model_path = path
                continue
            if path in finished_checkpoints or path in skipped_checkpoints:
                continue
            model_iter = _parse_iteration(path)
            if model_iter is None:
                # E.g. an extra "*.pth" file that is not "model_<iter>.pth";
                # ignore it rather than crash the whole monitor.
                skipped_checkpoints.add(path)
                logger.warning(
                    "Skip checkpoint with unrecognized name: {}".format(path)
                )
                continue
            periodic_checkpoints.append((path, model_iter))

        # Stable sort: ties on the iteration keep the path order from above.
        periodic_checkpoints.sort(key=lambda pc: pc[1])
        for pc in periodic_checkpoints:
            yield _add_and_log(pc[0])
            # Progress was made, so go back to polling at the fastest rate.
            sleep_interval = MIN_SLEEP_INTERVAL

        if final_model_path is None:
            sleep_interval = _log_and_sleep(sleep_interval)
        else:
            yield _add_and_log(final_model_path)
            break