def load_checkpoint()

in picotron/checkpoint.py [0:0]


    def load_checkpoint(self, model, optimizer, out_dir):
        """Load the model/optimizer states from the latest checkpoint.

        Assumes the parallel topology is identical to the one used when the
        checkpoint was written (state-dict keys and shard layout must match).

        Args:
            model: model to restore; if data/context parallelism is active
                (``cp_dp_world_size > 1``) it is assumed to be a wrapper
                (e.g. DDP) whose real module lives under ``model.module``.
            optimizer: optimizer whose state is restored in place.
            out_dir: directory the checkpoint path is resolved from.

        Returns:
            tuple: ``(trained_steps, trained_tokens)`` recorded in the
            checkpoint.

        Raises:
            FileNotFoundError: if no checkpoint file exists at the resolved
                path.
        """
        path = self._get_checkpoint_path(out_dir)

        if not os.path.exists(path):
            raise FileNotFoundError(f"Checkpoint not found at {path}")

        # map_location="cpu" avoids materializing tensors on whatever GPU
        # they were saved from (which may not be this rank's device and can
        # OOM). load_state_dict copies into the existing parameters, and
        # Optimizer.load_state_dict casts state to each param's device.
        checkpoint = torch.load(path, map_location="cpu")

        # Unwrap the parallel wrapper so state-dict keys line up with the
        # raw module's parameter names.
        raw_model = model.module if self.cp_dp_world_size > 1 else model
        raw_model.load_state_dict(checkpoint['model'])

        # Restore optimizer state in place.
        optimizer.load_state_dict(checkpoint['optimizer'])

        return checkpoint['trained_steps'], checkpoint['trained_tokens']