in picotron/checkpoint.py [0:0]
def load_checkpoint(self, model, optimizer, out_dir):
    """Restore model and optimizer state from the latest checkpoint.

    Assumes the parallel topology is identical to the one used when the
    checkpoint was written.

    Args:
        model: the (possibly DDP-wrapped) model to restore into.
        optimizer: optimizer whose state dict is restored in place.
        out_dir: directory from which the checkpoint path is resolved.

    Returns:
        Tuple ``(trained_steps, trained_tokens)`` read from the checkpoint.

    Raises:
        FileNotFoundError: if no checkpoint exists at the resolved path.
    """
    ckpt_path = self._get_checkpoint_path(out_dir)
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}")
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources (consider weights_only=True upstream).
    state = torch.load(ckpt_path)
    # Unwrap the DDP container when data/context parallelism is active so the
    # state-dict keys line up with the raw module's parameters.
    if self.cp_dp_world_size > 1:
        target_model = model.module
    else:
        target_model = model
    target_model.load_state_dict(state['model'])
    # Restore optimizer state (momentum buffers, step counts, ...).
    optimizer.load_state_dict(state['optimizer'])
    return state['trained_steps'], state['trained_tokens']