chatlearn/checkpoint/checkpoint_manager.py (110 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Checkpoint Manager.""" import os import pickle import shutil from chatlearn.utils.logger import log_rank_0 def path_exists(path): """path exits""" return path and os.path.exists(path) class CheckpointManager: """Checkpoint Manager""" def __init__(self, model, path, max_ckpt_nums, load_iteration=None, config_to_check=None): self._path = path self._max_ckpt_nums = max_ckpt_nums self._meta_file = os.path.join(self._path, "latest_checkpointed_iteration.txt") if not path_exists(path): os.makedirs(path, exist_ok=True) self.load_iteration = load_iteration self._meta = None self._model = model self._resumed = False self._config_to_check = {} if config_to_check is None else config_to_check def _get_checkpoint_path_name(self, replica_id, step): name = "replica{}_step{}".format(replica_id, step) ckpt_path = os.path.join(self._path, name) return ckpt_path def _make_checkpoint_path(self, replica_id, step): ckpt_path = self._get_checkpoint_path_name(replica_id, step) if not path_exists(ckpt_path): os.makedirs(ckpt_path, exist_ok=True) return ckpt_path def _delete_ckpt_files(self): """Delete checkpoint files.""" ckpt_path = self._path ckpt_files = os.listdir(ckpt_path) ckpt_folders = [(os.path.join(ckpt_path, f), os.path.getmtime(os.path.join(ckpt_path, f))) for f in ckpt_files if os.path.isdir(os.path.join(ckpt_path, f))] ckpt_folders.sort(key=lambda x: x[1], reverse=True) ckpt_folders = [f[0] for f in ckpt_folders] reserved_folders = ckpt_folders[:self._max_ckpt_nums] for folder in ckpt_folders: if os.path.isdir(folder): if folder in reserved_folders: continue try: shutil.rmtree(folder) except PermissionError: log_rank_0("Permission Denied: Please check the checkpoint file permissions.") def save_checkpoint(self, replica_id, train_iter, episode, consumed_samples): """save data checkpoint""" ckpt_path = self._make_checkpoint_path(replica_id, train_iter) log_rank_0( f"save data checkpoint to {ckpt_path}, replica: {replica_id}, train_iter: {train_iter}, episode: {episode} " + \ f"consumed samples {consumed_samples}") def _get_path(fn): return os.path.join(ckpt_path, fn) meta_data = {"episode": episode, "train_iteration": train_iter, "consumed_samples": consumed_samples, "sample_per_episode": self._model.runtime_args.sample_per_episode, "data_ratio": self._model.runtime_args.data_ratio} with open(_get_path("meta.pkl"), 'wb') as f: pickle.dump(meta_data, f) if replica_id == 0: self._set_latest_iteration(train_iter) # only reserve max nums of ckpt folders if needed if isinstance(self._max_ckpt_nums, int): self._delete_ckpt_files() log_rank_0("Checkpointing is done.") return True def _set_latest_iteration(self, iteration): with open(self._meta_file, 'w', encoding='utf-8') as f: f.write(f"{iteration}") def _get_checkpoint_path(self): """Get checkpoint path.""" latest_iter = self._get_latest_iteration() if latest_iter is None: log_rank_0(f"{self._meta_file} not found or load_iteration is not provided") return ckpt_path = self._get_checkpoint_path_name(self._model.replica_id, latest_iter) if path_exists(ckpt_path): log_rank_0(f"get checkpoint path from {self._path}") return ckpt_path log_rank_0(f"checkpoint path {ckpt_path} not exists") return def validate(self, ckpt_meta): for key, value in self._config_to_check.items(): assert value == ckpt_meta[key], \ f"config {key}: {value} diff with ckpt config {ckpt_meta[key]}" def resume_meta(self): if self._meta is not None: return self._meta ckpt_dir = self._get_checkpoint_path() if ckpt_dir is None: return with open(os.path.join(ckpt_dir, "meta.pkl"), 'rb') as f: self._meta = pickle.load(f) self.validate(self._meta) return self._meta def resume(self): """Resume data structures.""" if self._resumed: return self._meta meta = self.resume_meta() if meta is not None: self._model.runtime_args.consumed_samples = meta["consumed_samples"] log_rank_0(f"set consumed_samples to {meta['consumed_samples']}") self._model.runtime_args.data_ratio = data_ratio = meta.get("data_ratio", None) log_rank_0(f"set data_ratio to {data_ratio}") self._resumed = True return meta def _get_latest_iteration(self): if self.load_iteration is not None: return self.load_iteration if not path_exists(self._meta_file): return with open(self._meta_file, encoding='utf-8') as f: iteration = f.read().strip() return iteration