in pycls/core/checkpoint.py
def save_checkpoint(model, model_ema, optimizer, epoch, test_err, ema_err):
    """Saves a checkpoint and also the best weights so far in a best checkpoint."""
    # Save checkpoints only from the main process
    if not dist.is_main_proc():
        return
    # Ensure that the checkpoint dir exists
    pathmgr.mkdirs(get_checkpoint_dir())
    # Record the state
    checkpoint = {
        "epoch": epoch,
        "test_err": test_err,
        "ema_err": ema_err,
        "model_state": unwrap_model(model).state_dict(),
        "ema_state": unwrap_model(model_ema).state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "cfg": cfg.dump(),
    }
    # Write the checkpoint
    checkpoint_file = get_checkpoint(epoch + 1)
    with pathmgr.open(checkpoint_file, "wb") as f:
        torch.save(checkpoint, f)
    # Store the best model and model_ema weights so far
    if not pathmgr.exists(get_checkpoint_best()):
        pathmgr.copy(checkpoint_file, get_checkpoint_best())
    else:
        with pathmgr.open(get_checkpoint_best(), "rb") as f:
            best = torch.load(f, map_location="cpu")
        # Select the best model weights and the best model_ema weights
        if test_err < best["test_err"] or ema_err < best["ema_err"]:
            if test_err < best["test_err"]:
                best["model_state"] = checkpoint["model_state"]
                best["test_err"] = test_err
            if ema_err < best["ema_err"]:
                best["ema_state"] = checkpoint["ema_state"]
                best["ema_err"] = ema_err
            with pathmgr.open(get_checkpoint_best(), "wb") as f:
                torch.save(best, f)
    return checkpoint_file
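
The record written above can be read back by reversing the same steps. The snippet below is a minimal sketch, assuming it runs inside this same module (so pathmgr, get_checkpoint_best, and unwrap_model are in scope); load_best_weights is a hypothetical helper for illustration, not necessarily the repository's own loading routine.

def load_best_weights(model, model_ema):
    """Sketch: restore the best model/model_ema weights saved by save_checkpoint."""
    with pathmgr.open(get_checkpoint_best(), "rb") as f:
        best = torch.load(f, map_location="cpu")
    # These keys mirror the dict assembled in save_checkpoint above
    unwrap_model(model).load_state_dict(best["model_state"])
    unwrap_model(model_ema).load_state_dict(best["ema_state"])
    return best["test_err"], best["ema_err"]

Note that save_checkpoint names the file via get_checkpoint(epoch + 1), so a zero-based epoch index produces checkpoint files numbered from 1, while the "best" checkpoint is tracked separately and only rewritten when either error improves.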