in src/hyperpod_nemo_adapter/utils/callbacks/checkpoint.py [0:0]
def get_interval(self, checkpoint_io, ckpt_preprocessing_duration):
    """Compute the automatic checkpoint interval.

    Called once per training step. During warmup it records timing samples
    and returns a provisional interval. Once enough samples exist it calls
    ``compute_auto_checkpoint_interval``, caches the result in
    ``self._interval`` and returns the cached value on every later call.

    Phases, keyed on ``self.step`` (each warmup phase collects
    ``self.warmup_steps`` samples)::

        |<- drop ->|<--- checkpoint warmup --->|<--- step warmup --->|
                     step_ckpt_duration,          step_duration      ^
                     ckpt_preprocessing_duration,                    |
                     io_duration                    all_reduce + compute interval

    * ``step <= warmup_start`` ("drop"): nothing recorded, returns 1.
    * ``warmup_start < step <= ckpt_warmup_end`` (checkpoint warmup): records
      the step duration in ``step_ckpt_durations``, the given
      ``ckpt_preprocessing_duration`` and ``checkpoint_io.io_duration``.
      On ``step == ckpt_warmup_end`` it first calls ``checkpoint_io.wait()``.
      Returns 1, except 0 for the last two steps of this phase.
    * ``ckpt_warmup_end < step < step_warmup_end`` (step warmup): records the
      step duration in ``step_durations``, returns 0.
    * otherwise: records the final step duration, checks that every sample
      list holds exactly ``warmup_steps`` entries, then computes, caches,
      logs and returns the interval.

    Args:
        checkpoint_io: Object exposing an ``io_duration`` attribute and a
            ``wait()`` method.
        ckpt_preprocessing_duration: Checkpoint preprocessing time reported by
            the caller for this step; stored as-is (units as supplied).

    Returns:
        The cached ``self._interval.interval`` once it is positive. Otherwise
        1 or 0 as described above. NOTE(review): presumably 0 means "do not
        checkpoint" here — confirm against the caller.

    Raises:
        AssertionError: If the collected sample counts do not all equal
            ``self.warmup_steps`` when the interval is computed.
    """
    # Only a positive cached interval short-circuits; otherwise keep sampling.
    if self._interval.interval > 0:
        return self._interval.interval

    step_duration = self._end - self._start
    if step_duration == 0:
        # No usable timing for this step; nothing is recorded.
        return 1
    if self.step <= self.warmup_start:
        # "drop" phase: early steps are not sampled.
        return 1

    if self.step <= self.ckpt_warmup_end:
        if self.step == self.ckpt_warmup_end:
            # Last checkpoint sample: presumably blocks until in-flight
            # checkpoint IO finishes, so io_duration below is final.
            checkpoint_io.wait()
        ckpt_io_duration = checkpoint_io.io_duration
        self.step_ckpt_durations.append(step_duration)
        self.ckpt_preprocessing_durations.append(ckpt_preprocessing_duration)
        self.ckpt_io_durations.append(ckpt_io_duration)
        # 1 for every checkpoint-warmup step except the last two, which return 0.
        return int(self.step < self.ckpt_warmup_end - 1)

    # Step warmup: sample plain step durations.
    self.step_durations.append(step_duration)
    if self.step < self.step_warmup_end:
        return 0

    # All samples collected; each list must hold exactly warmup_steps entries.
    for name, samples in (
        ("step_durations", self.step_durations),
        ("step_ckpt_durations", self.step_ckpt_durations),
        ("ckpt_preprocessing_durations", self.ckpt_preprocessing_durations),
        ("ckpt_io_durations", self.ckpt_io_durations),
    ):
        assert len(samples) == self.warmup_steps, (
            f"{name} has {len(samples)} samples, expected {self.warmup_steps}"
        )

    self._interval = compute_auto_checkpoint_interval(
        self.step_durations,
        self.step_ckpt_durations,
        self.ckpt_preprocessing_durations,
        self.ckpt_io_durations,
    )
    logging.info("[CHECKPOINT INFO] %s", self)
    return self._interval.interval