in ignite/handlers/checkpoint.py [0:0]
def __call__(self, engine: Engine) -> None:
    """Save a checkpoint of the current engine state if it qualifies.

    The priority of the new checkpoint is the output of ``score_function`` when one
    is configured. Otherwise it is the global step, taken from
    ``global_step_transform`` or, when that is not set, from the engine's
    ``Events.ITERATION_COMPLETED`` counter.

    The checkpoint is written when ``self._check_lt_n_saved()`` is true or
    ``self._compare_fn(priority)`` accepts it. Before the new item is recorded, one
    of these is removed from ``self._saved`` (and deleted through the handler's
    ``remove`` when the handler is a ``BaseSaveHandler``):

    - an already-saved item with the same filename, or
    - when ``_check_lt_n_saved()`` is false, the lowest-priority item
      (``_saved`` is kept sorted by ascending priority).

    Args:
        engine: engine whose state is being checkpointed.

    Raises:
        ValueError: if ``score_function`` returns something that is not a number.
    """
    import inspect

    global_step = None
    if self.global_step_transform is not None:
        global_step = self.global_step_transform(engine, engine.last_event_name)

    if self.score_function is not None:
        priority = self.score_function(engine)
        if not isinstance(priority, numbers.Number):
            raise ValueError("Output of score_function should be a number")
    else:
        # Without a score function, newer steps simply rank higher.
        if global_step is None:
            global_step = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED)
        priority = global_step

    if not (self._check_lt_n_saved() or self._compare_fn(priority)):
        return

    priority_str = f"{priority}" if isinstance(priority, numbers.Integral) else f"{priority:.4f}"

    checkpoint = self._setup_checkpoint()
    name = "checkpoint"
    if len(checkpoint) == 1:
        # A single object is saved unwrapped, and the file is named after its key.
        name = next(iter(checkpoint))
        checkpoint = checkpoint[name]

    if self.filename_pattern is None:
        filename_pattern = self.setup_filename_pattern(
            with_prefix=len(self.filename_prefix) > 0,
            with_score=self.score_function is not None,
            with_score_name=self.score_name is not None,
            with_global_step=global_step is not None,
        )
    else:
        filename_pattern = self.filename_pattern

    filename_dict = {
        "filename_prefix": self.filename_prefix,
        "ext": self.ext,
        "name": name,
        "score_name": self.score_name,
        "score": priority_str if (self.score_function is not None) else None,
        "global_step": global_step,
    }
    filename = filename_pattern.format(**filename_dict)

    basename = f"{self.filename_prefix}_{name}" if len(self.filename_prefix) > 0 else name
    metadata = {
        "basename": basename,
        "score_name": self.score_name,
        "priority": priority,
    }

    # An item already saved under the same filename is always replaced; otherwise the
    # lowest-priority item (index 0, since _saved is sorted ascending) is evicted only
    # when there is no spare capacity.
    index = next((i for i, item in enumerate(self._saved) if item.filename == filename), None)
    if index is not None:
        to_remove = True
    else:
        index = 0
        to_remove = not self._check_lt_n_saved()

    if to_remove:
        item = self._saved.pop(index)
        if isinstance(self.save_handler, BaseSaveHandler):
            self.save_handler.remove(item.filename)

    self._saved.append(Checkpoint.Item(priority, filename))
    self._saved.sort(key=lambda it: it[0])

    if self.include_self:
        # Now that we've updated _saved, we can add our own state_dict.
        checkpoint["checkpointer"] = self.state_dict()

    # Decide up front whether the handler takes the metadata argument. The former
    # "call with 3 args, retry with 2 on TypeError" approach also caught TypeErrors
    # raised *inside* the handler. It then re-invoked the handler and masked the real
    # error.
    try:
        handler_signature = inspect.signature(self.save_handler)
    except (TypeError, ValueError):
        handler_signature = None

    if handler_signature is None:
        # No introspectable signature: fall back to the legacy trial call.
        try:
            self.save_handler(checkpoint, filename, metadata)
        except TypeError:
            self.save_handler(checkpoint, filename)
        return

    try:
        handler_signature.bind(checkpoint, filename, metadata)
    except TypeError:
        self.save_handler(checkpoint, filename)
    else:
        self.save_handler(checkpoint, filename, metadata)