def save_checkpoint()

in picotron/checkpoint.py [0:0]


    def save_checkpoint(self, model, optimizer, trained_steps, trained_tokens, out_dir):
        """Save the model/optimizer states and training progress to a checkpoint file.

        The file is written to a temporary sibling path first and then moved
        onto the final path, so a failure in the middle of the save never
        destroys a checkpoint that already exists at that path.

        Args:
            model: Object exposing ``state_dict()``. When
                ``self.cp_dp_world_size > 1`` the underlying model is taken
                from ``model.module`` (presumably a data-parallel wrapper --
                confirm against the caller).
            optimizer: Object exposing ``state_dict()``.
            trained_steps: Step counter stored under ``'trained_steps'``.
            trained_tokens: Token counter stored under ``'trained_tokens'``.
            out_dir: Directory created if missing; also passed to
                ``self._get_checkpoint_path`` to build the target file path.

        Raises:
            Whatever ``torch.save`` or the filesystem calls raise; in that
            case the temporary file is removed and the error is re-raised.
        """
        path = self._get_checkpoint_path(out_dir)

        # Only DP/CP rank 0 will save the model, the weights are the same across all ranks
        if self.dp_rank != 0 or self.cp_rank != 0:
            return

        os.makedirs(out_dir, exist_ok=True)
        raw_model = model.module if self.cp_dp_world_size > 1 else model
        checkpoint = {
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'trained_steps': trained_steps,
            'trained_tokens': trained_tokens,
        }

        # Write next to the target (same directory, hence same filesystem) so
        # the final os.replace is a rename rather than a copy.
        tmp_path = f"{path}.tmp"
        try:
            torch.save(checkpoint, tmp_path)
            os.replace(tmp_path, path)
        finally:
            # After a successful replace the temp file is gone; it only still
            # exists here when the save failed part-way through.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)