tzrec/utils/checkpoint_util.py (150 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os
from dataclasses import replace
from typing import List, Optional, Tuple

from torch import nn, optim
from torch.distributed.checkpoint import (
    FileSystemReader,
    TensorStorageMetadata,
    load,
    save,
)
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    DTensor,
    LoadPlan,
    _create_read_items,
)

from tzrec.utils.logging_util import logger


class PartialLoadPlanner(DefaultLoadPlanner):
    """Support restoring partial states.

    Args:
        flatten_state_dict (bool): Handle state_dict with nested dicts.
        flatten_sharded_tensors (bool): For FSDP in 2D parallel mode.
        ckpt_param_map_path (str): parameter name mapping file for the checkpoint,
            one "<current_param_name> <checkpoint_param_name>" pair per line.
    """

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
        ckpt_param_map_path: Optional[str] = None,
    ) -> None:
        super().__init__(flatten_state_dict, flatten_sharded_tensors)
        self._ckpt_param_map = dict()
        if ckpt_param_map_path:
            with open(ckpt_param_map_path) as f:
                for line in f.readlines():
                    cur_param_name, old_param_name = line.strip().split()
                    self._ckpt_param_map[cur_param_name] = old_param_name

    def create_local_plan(self) -> LoadPlan:
        """Create local load plan."""
        requests = []
        # pyre-ignore [16]
        for fqn, obj in self.state_dict.items():
            meta_fqn = fqn
            if fqn in self._ckpt_param_map:
                meta_fqn = self._ckpt_param_map[fqn]
                logger.info(f"Remap restore state [{fqn}] from [{meta_fqn}]")
            # pyre-ignore [16]
            if meta_fqn in self.metadata.state_dict_metadata:
                md = self.metadata.state_dict_metadata[meta_fqn]
            else:
                logger.warning(f"Skip restore state [{fqn}]")
                continue
            read_items = []
            if isinstance(obj, DTensor):
                if obj.device_mesh.get_coordinate() is not None:
                    read_items = _create_read_items(meta_fqn, md, obj)
            else:
                read_items = _create_read_items(meta_fqn, md, obj)
            if fqn in self._ckpt_param_map:
                read_items = [
                    replace(x, dest_index=replace(x.dest_index, fqn=fqn))
                    for x in read_items
                ]
            requests += read_items
        plan = LoadPlan(requests)
        return plan


def _get_checkpoint_step(ckpt_path: str) -> int:
    """Get checkpoint step from ckpt_path.

    Args:
        ckpt_path: checkpoint path, such as xx/model.ckpt-2000.

    Return:
        ckpt_step: checkpoint step, such as 2000.
    """
    _, ckpt_name = os.path.split(ckpt_path)
    ckpt_name, ext = os.path.splitext(ckpt_name)
    if ext.startswith(".ckpt-"):
        ckpt_name = ext
    toks = ckpt_name.split("-")
    try:
        ckpt_step = int(toks[-1])
    except Exception:
        ckpt_step = 0
    return ckpt_step


def latest_checkpoint(model_dir: str) -> Tuple[Optional[str], int]:
    """Find latest checkpoint under a directory.

    Args:
        model_dir: model directory.

    Return:
        latest_ckpt_path: latest checkpoint path.
        latest_step: step of the latest checkpoint.
    """
    if "model.ckpt-" not in model_dir:
        ckpt_metas = glob.glob(os.path.join(model_dir, "model.ckpt-*"))
        if len(ckpt_metas) == 0:
            model_ckpt_dir = os.path.join(model_dir, "model")
            optim_ckpt_dir = os.path.join(model_dir, "optimizer")
            if os.path.exists(model_ckpt_dir) or os.path.exists(optim_ckpt_dir):
                return model_dir, 0
            else:
                return None, -1
        if len(ckpt_metas) > 1:
            ckpt_metas.sort(key=lambda x: _get_checkpoint_step(x))
        latest_ckpt_path = ckpt_metas[-1]
    else:
        latest_ckpt_path = model_dir
    return latest_ckpt_path, _get_checkpoint_step(latest_ckpt_path)


def restore_model(
    checkpoint_dir: str,
    model: nn.Module,
    optimizer: Optional[optim.Optimizer] = None,
    ckpt_param_map_path: Optional[str] = None,
) -> None:
    """Restore model state.

    Args:
        checkpoint_dir (str): EasyRec model checkpoint dir.
        model (nn.Module): an EasyRec model.
        optimizer (optim.Optimizer, optional): an optimizer.
        ckpt_param_map_path (str): parameter name mapping file for the checkpoint.
    """
    is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0
    if is_local_rank_zero:
        logger.info(f"Restoring checkpoint from {checkpoint_dir}...")
    if not os.path.exists(checkpoint_dir):
        raise RuntimeError(f"checkpoint_dir[{checkpoint_dir}] does not exist.")
    model_ckpt_path = os.path.join(checkpoint_dir, "model")
    optim_ckpt_path = os.path.join(checkpoint_dir, "optimizer")
    if os.path.exists(model_ckpt_path):
        if is_local_rank_zero:
            logger.info(f"Restoring model state from {model_ckpt_path}...")
        state_dict = model.state_dict()
        load(
            state_dict,
            checkpoint_id=model_ckpt_path,
            planner=PartialLoadPlanner(ckpt_param_map_path=ckpt_param_map_path),
        )
        model.load_state_dict(state_dict)
    if optimizer and os.path.exists(optim_ckpt_path):
        if is_local_rank_zero:
            logger.info(f"Restoring optimizer state from {optim_ckpt_path}...")
        state_dict = optimizer.state_dict()
        load(
            state_dict,
            checkpoint_id=optim_ckpt_path,
            planner=PartialLoadPlanner(ckpt_param_map_path=ckpt_param_map_path),
        )
        optimizer.load_state_dict(state_dict)


def save_model(
    checkpoint_dir: str, model: nn.Module, optimizer: Optional[optim.Optimizer] = None
) -> None:
    """Save model state.

    Args:
        checkpoint_dir (str): EasyRec model checkpoint dir.
        model (nn.Module): an EasyRec model.
        optimizer (optim.Optimizer, optional): an optimizer.
    """
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        logger.info(f"Saving checkpoint to {checkpoint_dir}...")
    save(model.state_dict(), checkpoint_id=os.path.join(checkpoint_dir, "model"))
    if optimizer:
        save(
            optimizer.state_dict(),
            checkpoint_id=os.path.join(checkpoint_dir, "optimizer"),
        )


def list_distcp_param(checkpoint_dir: str) -> List[str]:
    """List parameter names stored in a distributed checkpoint.

    Args:
        checkpoint_dir (str): distributed checkpoint dir.

    Return:
        param_names: names of the tensors stored in the checkpoint.
    """
    meta_paths = []
    if os.path.exists(os.path.join(checkpoint_dir, ".metadata")):
        meta_paths.append(checkpoint_dir)
    else:
        if os.path.exists(os.path.join(checkpoint_dir, "model", ".metadata")):
            meta_paths.append(os.path.join(checkpoint_dir, "model"))
        if os.path.exists(os.path.join(checkpoint_dir, "optimizer", ".metadata")):
            meta_paths.append(os.path.join(checkpoint_dir, "optimizer"))
    if len(meta_paths) == 0:
        raise RuntimeError(f"Can't find distributed checkpoint in {checkpoint_dir}")
    param_names = []
    for meta_path in meta_paths:
        reader = FileSystemReader(path=meta_path)
        meta = reader.read_metadata()
        logger.info(f"Params in {meta_path}:")
        for k, v in meta.state_dict_metadata.items():
            if isinstance(v, TensorStorageMetadata):
                param_names.append(k)
                logger.info(f"{k}: {v.size}")
    return param_names
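

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): a minimal
# example of wiring latest_checkpoint / save_model / restore_model /
# list_distcp_param together for a toy nn.Linear model under a temporary
# directory. The toy model, optimizer, and the "model.ckpt-100" step name are
# hypothetical placeholders; the __main__ guard keeps importing the module
# side-effect free. For partial restores, restore_model additionally accepts
# ckpt_param_map_path pointing at the mapping file described in
# PartialLoadPlanner's docstring.
if __name__ == "__main__":
    import tempfile

    import torch

    toy_model = nn.Linear(4, 2)
    toy_optimizer = optim.Adam(toy_model.parameters())

    # One dummy training step so the optimizer has state worth checkpointing.
    toy_model(torch.randn(3, 4)).sum().backward()
    toy_optimizer.step()

    with tempfile.TemporaryDirectory() as model_dir:
        # No checkpoint yet: latest_checkpoint returns (None, -1).
        ckpt_path, step = latest_checkpoint(model_dir)
        if ckpt_path is not None:
            restore_model(ckpt_path, toy_model, toy_optimizer)

        # Persist a checkpoint named by global step.
        new_ckpt = os.path.join(model_dir, "model.ckpt-100")
        save_model(new_ckpt, toy_model, toy_optimizer)

        # A later run picks it up as the latest checkpoint and restores it.
        ckpt_path, step = latest_checkpoint(model_dir)
        if ckpt_path is not None:
            restore_model(ckpt_path, toy_model, toy_optimizer)
            # list_distcp_param logs and returns the stored tensor names.
            param_names = list_distcp_param(ckpt_path)
            logger.info(f"restored step {step} with params: {param_names}")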