in src/nanotron/trainer.py [0:0]
def save_checkpoint(self) -> Path:
    self.pre_save_checkpoint()

    checkpoints_path = self.config.checkpoints.checkpoints_path
    checkpoint_path = Path(checkpoints_path) / f"{self.iteration_step}"
    if self.config.checkpoints.checkpoints_path_is_shared_file_system:
        # Shared file system: a single global rank creates the directory for everyone.
        should_mkdir = dist.get_rank(self.parallel_context.world_pg) == 0
    else:
        # Node-local storage: the first process on each node (LOCAL_RANK == 0) creates it.
        # Defaulting to "0" avoids int(None) raising when LOCAL_RANK is not set.
        should_mkdir = int(os.environ.get("LOCAL_RANK", "0")) == 0
    if should_mkdir:
        checkpoint_path.mkdir(parents=True, exist_ok=True)
    dist.barrier(self.parallel_context.world_pg)
log_rank(f"Saving checkpoint at {checkpoint_path}", logger=logger, level=logging.WARNING, rank=0)
# Update step/samples numbers before we save the config
self.config.general.step = self.metadata.last_train_step
self.config.general.consumed_train_samples = self.metadata.consumed_train_samples # TODO: idc abt this
    save(
        model=self.unwrapped_model,
        optimizer=self.optimizer,
        lr_scheduler=self.lr_scheduler,
        should_save_model=bool(
            dist.get_rank(self.parallel_context.dp_cp_pg) == 0
        ),  # We only save the weights on DP_CP==0
        should_save_optimizer=True,
        should_save_lr_scheduler=True,
        should_save_config=bool(
            dist.get_rank(self.parallel_context.world_pg) == 0
        ),  # We only save the config on world_rank==0
        parallel_context=self.parallel_context,
        root_folder=checkpoint_path,
        training_metadata=self.metadata,
        config=self.config,
        sanity_checks=not self.config.general.ignore_sanity_checks,
    )
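    # Persist the RNG states as well, so that data shuffling and dropout streams can be restored on resume.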
    save_random_states(
        random_states=self.random_states, parallel_context=self.parallel_context, root_folder=checkpoint_path
    )
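    # Record the newest iteration step at the checkpoints root so resume logic can locate the latest checkpoint.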
    with open(Path(checkpoints_path) / "latest.txt", mode="w") as fo:
        fo.write(f"{self.iteration_step}")
    # Serialize the model config: use to_json_file when available (e.g. HF-style configs),
    # otherwise dump the dataclass via asdict().
    if hasattr(self.model_config, "to_json_file"):
        self.model_config.to_json_file(checkpoint_path / MODEL_CONFIG_FILE_NAME)
    else:
        with open(checkpoint_path / MODEL_CONFIG_FILE_NAME, mode="w") as fo:
            fo.write(json.dumps(asdict(self.model_config)))
    self.post_save_checkpoint()
    return checkpoint_path
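
For context, a minimal sketch of how a caller might consume the "latest.txt" pointer written above to resolve the most recent checkpoint directory before resuming; the helper name `find_latest_checkpoint` and the standalone layout are illustrations, not nanotron's actual loading API.

# Usage sketch (separate from trainer.py); assumes only the latest.txt layout written by save_checkpoint.
from pathlib import Path
from typing import Optional


def find_latest_checkpoint(checkpoints_path: Path) -> Optional[Path]:
    """Return the checkpoint directory recorded in latest.txt, or None if absent."""
    latest_file = checkpoints_path / "latest.txt"
    if not latest_file.exists():
        return None
    iteration_step = latest_file.read_text().strip()
    candidate = checkpoints_path / iteration_step
    return candidate if candidate.is_dir() else None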