in picotron/checkpoint.py [0:0]
import os

import torch

def save_checkpoint(self, model, optimizer, trained_steps, trained_tokens, out_dir):
    """Save the model/optimizer state dicts and training progress to a checkpoint file."""
    path = self._get_checkpoint_path(out_dir)
    # Only DP/CP rank 0 saves the checkpoint: the weights are replicated across
    # all data-parallel and context-parallel ranks, so one copy is enough.
    if self.dp_rank == 0 and self.cp_rank == 0:
        os.makedirs(out_dir, exist_ok=True)
        # Unwrap the DDP wrapper so the saved state dict has clean key names.
        raw_model = model.module if self.cp_dp_world_size > 1 else model
        checkpoint = {
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'trained_steps': trained_steps,
            'trained_tokens': trained_tokens
        }
        torch.save(checkpoint, path)
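For context, the load path would read the same dictionary back and restore both states. The following is a minimal sketch assuming the checkpoint layout above; the name load_checkpoint and its signature are illustrative, not necessarily picotron's actual API.

def load_checkpoint(self, model, optimizer, out_dir):
    """Hypothetical counterpart to save_checkpoint (illustrative sketch)."""
    path = self._get_checkpoint_path(out_dir)
    # Every rank loads the same file; map_location='cpu' avoids all ranks
    # deserializing tensors onto the same GPU before they are moved.
    checkpoint = torch.load(path, map_location='cpu')
    # Mirror the unwrapping done at save time so the state dict keys match.
    raw_model = model.module if self.cp_dp_world_size > 1 else model
    raw_model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['trained_steps'], checkpoint['trained_tokens']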