chatlearn/models/fsdp_module.py (339 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FSDP module"""

import os
import random
import functools
import gc

import ray
import numpy as np
import torch
import torch.distributed as dist
from torch import optim, nn
from torch.distributed.fsdp import (MixedPrecision, ShardingStrategy,
                                    ShardedOptimStateDictConfig, ShardedStateDictConfig,
                                    FullStateDictConfig, StateDictType,
                                    FullyShardedDataParallel as FSDP)
from torch.distributed.fsdp.wrap import (size_based_auto_wrap_policy,
                                         transformer_auto_wrap_policy,
                                         lambda_auto_wrap_policy, _or_policy)
from torch.distributed.fsdp._runtime_utils import _lazy_init
from torch.multiprocessing.reductions import reduce_tensor
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.trainer_pt_utils import get_module_class_from_name

from chatlearn.utils.logger import debug_rank_0, log_rank_0
from chatlearn.utils.utils import dict_to_simplenamespace
from .torch_module import TorchModule


class FSDPModule(TorchModule):
    """FSDPModule is the class for Alignment models based on PyTorch FSDP.

    Args
    ----
    name : str
        model name
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not self.trainable:
            # inference only
            if (
                self.model_args.get("train_micro_batch_size")
                != self.module_args.generation_batch_size
            ):
                self._logger.info(
                    f"{self.name} Overwrite train_micro_batch_size with generation_batch_size {self.module_args.generation_batch_size}"
                )
            self.train_micro_batch_size = self.module_args.generation_batch_size
        else:
            self.train_micro_batch_size = self.runtime_args.train_micro_batch_size
            self.train_global_batch_size = self.runtime_args.train_global_batch_size
        self.fsdp_size = self.module_args.fsdp_size
        self.device_mesh = None

    def get_visible_gpus(self):
        """
        :meta private:
        """
        return ray.get_gpu_ids()

    @staticmethod
    def init_fn(x: torch.nn.Module):
        # Only rank 0 keeps the loaded weights; other ranks materialize empty
        # storage and rely on sync_module_states to receive the parameters.
        if torch.distributed.get_rank() != 0:
            x = x.to_empty(device=torch.cuda.current_device(), recurse=False)
            torch.cuda.empty_cache()
        return x

    @staticmethod
    def get_fsdp_wrap_policy(module: torch.nn.Module, min_num_params: int = 0):
        """Get the FSDP auto-wrap policy for the module.

        Args:
            module: the module to build the wrap policy for.
            min_num_params: minimum number of parameters for the size-based wrap
                policy; must be 0 to use transformer_auto_wrap_policy.
        """
        default_transformer_cls_names_to_wrap = getattr(module, "_no_split_modules", None)
        fsdp_transformer_layer_cls_to_wrap = default_transformer_cls_names_to_wrap
        auto_wrap_policy = None
        policies = []
        # Add lambda policy for LoRA modules if is_lora is True
        # min_num_params must be 0 to use transformer_auto_wrap_policy
        if min_num_params > 0:
            size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)
            policies.append(size_policy)
        elif fsdp_transformer_layer_cls_to_wrap is not None:
            transformer_cls_to_wrap = set()
            for layer_class in fsdp_transformer_layer_cls_to_wrap:
                transformer_cls = get_module_class_from_name(module, layer_class)
                assert transformer_cls is not None, "Could not find the transformer layer class to wrap in the model."
                transformer_cls_to_wrap.add(transformer_cls)
            transformer_policy = functools.partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls=transformer_cls_to_wrap,
            )
            policies.append(transformer_policy)

        ### hardcoded for qwen2.5: FSDP-wrap these submodules so their state dicts can be fetched per block ###
        def lambda_fn(sub_module: nn.Module):
            if sub_module in [module.model.embed_tokens, module.model.norm, module.lm_head]:
                return True
            return False
        lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_fn)
        policies.append(lambda_policy)
        ### hardcoded for qwen2.5: FSDP-wrap these submodules so their state dicts can be fetched per block ###

        if len(policies) > 0:
            auto_wrap_policy = functools.partial(_or_policy, policies=policies)
        return auto_wrap_policy

    def create_device_mesh(self, world_size, fsdp_size):
        if not self.device_mesh:
            if world_size == fsdp_size:
                self.device_mesh = dist.device_mesh.init_device_mesh(
                    "cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]
                )
            else:
                self.device_mesh = dist.device_mesh.init_device_mesh(
                    "cuda",
                    mesh_shape=(fsdp_size, world_size // fsdp_size),
                    mesh_dim_names=["fsdp", "ddp"],
                )
        print(f"world size {world_size}, fsdp_size {fsdp_size}, {self.device_mesh}")

    def setup_distributed(self):
        print(self.get_dist_env())
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")
        self.create_device_mesh(self.world_size, self.fsdp_size)

    def peak_memory(self):
        """
        :meta private:
        """
        self._peak_memory = max(
            self._peak_memory, torch.cuda.max_memory_allocated() / (1024**3)
        )
        return self._peak_memory

    def empty_cache(self):
        """
        :meta private:
        """
        if not self.timers("empty_cache").started_:
            self.timers("empty_cache").start()
        peak_mem = torch.cuda.max_memory_allocated() / (1024**3)
        debug_rank_0(
            f"{self.name} replica: {self.replica_id}, before empty cache, peak mem: {peak_mem:.2f} GiB",
            self._logger,
        )
        # Manual gc
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        peak_mem = torch.cuda.max_memory_allocated() / (1024**3)
        debug_rank_0(
            f"{self.name} replica: {self.replica_id}, after empty cache, peak mem: {peak_mem:.2f} GiB",
            self._logger,
        )
        self.timers("empty_cache").stop()

    def check_param_exists(self, names):
        """
        check if the given names exist in the current model
        :meta private:
        """
        not_exists = []
        for name in names:
            if not self.exist_parameter(name):
                not_exists.append(name)
        if not_exists:
            log_rank_0(
                f"parameters do not exist: {not_exists} in model {self.name}",
                self._logger,
            )
            return False
        return True

    def create_model(self, model_path, torch_dtype):
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=model_path,
            torch_dtype=torch_dtype,
            attn_implementation="flash_attention_2",
            trust_remote_code=True,
        )
        return model

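    # Parallelism accessors: under pure FSDP there is no tensor, pipeline, or expert
    # parallelism, so the data-parallel group spans the whole world; the tensor-parallel
    # rank simply aliases the data-parallel rank and the remaining ranks/sizes are trivial.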
    @property
    def data_parallel_size(self):
        """
        :meta private:
        """
        return dist.get_world_size()

    @property
    def data_parallel_rank(self):
        """
        :meta private:
        """
        return dist.get_rank()

    def tensor_parallel_rank(self):
        return self.data_parallel_rank

    def pipeline_parallel_rank(self):
        return 1

    def expert_model_parallel_size(self):
        return 1

    def model_setup(self):
        """
        :meta private:
        """
        super().model_setup()
        self.setup_distributed()
        args = dict_to_simplenamespace(self.model_args)
        self.args = args
        model = self.create_model(args.pretrain_or_model, torch_dtype=torch.bfloat16)
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.pretrain_or_model, trust_remote_code=True, use_fast=True
        )
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
        mix_precision_config = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.float32,
        )
        sharding_strategy = ShardingStrategy.FULL_SHARD
        auto_wrap_policy = self.get_fsdp_wrap_policy(module=model)
        self.model = FSDP(
            model,
            cpu_offload=None,
            auto_wrap_policy=auto_wrap_policy,
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy,
            mixed_precision=mix_precision_config,
            sync_module_states=True,
            param_init_fn=FSDPModule.init_fn,
            device_mesh=self.device_mesh,
            forward_prefetch=False,
        )
        self.model.to(torch.float32)
        FSDP.set_state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT)
        if not self.trainable:
            self.optimizer = None
            self.model.eval()
        else:
            self.optimizer = optim.AdamW(
                self.model.parameters(),
                lr=self.module_args.args_dict.get("learning_rate", 2e-6),
                betas=(0.9, 0.999),
                weight_decay=1e-2,
            )
        # resume model weights
        if self.resume_training:
            self.load_checkpoint(self._episode_id)
        self.offload()

    def get_fsdp_param_name(self):
        name_list = []
        for name, _ in self.model.named_parameters():
            parts = name.split('.')
            filtered_parts = [
                part for part in parts
                if part not in {"_fsdp_wrapped_module", "_flat_param"}
            ]
            cleaned_name = '.'.join(filtered_parts)
            name_list.append(cleaned_name)
        return name_list

    def get_weight_ipc_handles_by_name(self, block_name: str):
        """
        Get the weights of a single FSDP-wrapped submodule by name.
        Gathering per submodule avoids materializing the full model state_dict at once.
        """
        torch.cuda.empty_cache()
        for prefix_name, module in self.model.named_modules():
            prefix_name = prefix_name.replace('_fsdp_wrapped_module.', '')
            if isinstance(module, FSDP) and prefix_name == block_name:
                state_dict = module.state_dict()
                reduce_tensor_dict = {}
                for name, param in state_dict.items():
                    reduce_tensor_dict['.'.join([prefix_name, name])] = reduce_tensor(param.full_tensor())
                return reduce_tensor_dict

    @torch.no_grad()
    def onload_weights(self, empty_cache=True):
        _lazy_init(self.model, self.model)
        assert self.model._is_root
        device_id = torch.cuda.current_device()
        for handle in self.model._all_handles:
            if handle._offload_params:
                continue
            flat_param = handle.flat_param
            handle.flat_param_to(torch.device(f"cuda:{device_id}"), non_blocking=True)
            # the following still keeps id(._local_shard) != id(.data)
            flat_param._local_shard = flat_param.data
        if empty_cache:
            torch.cuda.empty_cache()

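    # offload_weights mirrors onload_weights above: it moves every flat-parameter shard
    # to CPU, frees the unsharded (gathered) flat params, and skips handles that already
    # use FSDP's own CPU offloading (_offload_params).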
    @torch.no_grad()
    def offload_weights(self, empty_cache=True):
        assert isinstance(self.model, FSDP)
        # lazy init FSDP model
        _lazy_init(self.model, self.model)
        assert self.model._is_root, "Only support root model offloading to CPU"
        for handle in self.model._all_handles:
            if handle._offload_params:
                continue
            flat_param = handle.flat_param
            assert (
                flat_param.data.data_ptr() == flat_param._local_shard.data_ptr()
                and id(flat_param.data) != id(flat_param._local_shard)
                and flat_param.data.size() == flat_param._local_shard.size()
            )
            handle.flat_param_to(torch.device("cpu"), non_blocking=True)
            # Explicit call to free the unsharded flat param
            handle._free_unsharded_flat_param()
            # the following still keeps id(._local_shard) != id(.data)
            flat_param._local_shard = flat_param.data
            assert id(flat_param._local_shard) != id(flat_param.data)
        # Explicitly release IPC handles
        torch.cuda.ipc_collect()
        if empty_cache:
            torch.cuda.empty_cache()

    @torch.no_grad()
    def offload_optimizer_states(self, empty_cache=True):
        if not self.optimizer.state:
            return
        for param_group in self.optimizer.param_groups:
            for param in param_group["params"]:
                state = self.optimizer.state[param]
                for key, value in state.items():
                    if isinstance(value, torch.Tensor):
                        state[key] = value.to("cpu", non_blocking=True)
        if empty_cache:
            torch.cuda.empty_cache()

    @torch.no_grad()
    def onload_optimizer_states(self, empty_cache=True):
        if not self.optimizer.state:
            return
        device = torch.cuda.current_device()
        for param_group in self.optimizer.param_groups:
            for param in param_group["params"]:
                state = self.optimizer.state[param]
                for key, value in state.items():
                    if isinstance(value, torch.Tensor):
                        state[key] = value.to(device, non_blocking=True)
        if empty_cache:
            torch.cuda.empty_cache()

    def save_checkpoint(self, iteration):
        save_dir = f"{self.runtime_args.output_dir}/save_model/{self.name}/{iteration}"
        if self.data_parallel_rank == 0 and not os.path.exists(save_dir):
            os.makedirs(save_dir)
        state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
        optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
        with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
            model_state_dict = self.model.state_dict()
            optimizer_state_dict = self.optimizer.state_dict() if self.optimizer is not None else None
            # lr_scheduler_state_dict = self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None
            extra_state_dict = {
                # "lr_scheduler": lr_scheduler_state_dict,
                "rng": self.get_rng_state(),
            }
            model_path = os.path.join(save_dir, f"model_world_size_{self.data_parallel_size}_rank_{self.data_parallel_rank}.pt")
            optim_path = os.path.join(save_dir, f"optim_world_size_{self.data_parallel_size}_rank_{self.data_parallel_rank}.pt")
            extra_path = os.path.join(save_dir, f"extra_state_world_size_{self.data_parallel_size}_rank_{self.data_parallel_rank}.pt")
            torch.save(model_state_dict, model_path)
            torch.save(optimizer_state_dict, optim_path)
            torch.save(extra_state_dict, extra_path)
        torch.distributed.barrier()

        # save in HF format
        if self.model_args.get("save_hf", True):
            state_dict_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, state_dict_cfg, None):
                model_state_dict = self.model.state_dict()
                if self.data_parallel_rank == 0:
                    hf_path = os.path.join(save_dir, "huggingface")
                    os.makedirs(hf_path, exist_ok=True)
                    model_config = self.model._fsdp_wrapped_module.config
                    model_config.save_pretrained(hf_path)
                    self.tokenizer.save_pretrained(hf_path)
                    with torch.device("meta"):
                        save_model = AutoModelForCausalLM.from_config(model_config, torch_dtype=torch.bfloat16)
                    save_model.to_empty(device="cpu")
                    save_model.save_pretrained(hf_path, state_dict=model_state_dict)
        self._logger.info(f"save checkpoint to {save_dir}")

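    # load_checkpoint restores the per-rank sharded checkpoints written by save_checkpoint.
    # The file names embed the data-parallel world size and rank, so resuming assumes the
    # same data-parallel layout as the run that saved them.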
load") return model_path = os.path.join(load_dir, f"model_world_size_{self.data_parallel_size}_rank_{self.data_parallel_rank}.pt") optim_path = os.path.join(load_dir, f"optim_world_size_{self.data_parallel_size}_rank_{self.data_parallel_rank}.pt") extra_state_path = os.path.join(load_dir, f"extra_state_world_size_{self.data_parallel_size}_rank_{self.data_parallel_rank}.pt") model_state_dict = torch.load(model_path, weights_only=False) optimizer_state_dict = torch.load(optim_path, weights_only=False) extra_state_dict = torch.load(extra_state_path, weights_only=False) state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True) optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True) with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): self.model.load_state_dict(model_state_dict) if self.optimizer is not None: self.optimizer.load_state_dict(optimizer_state_dict) # recover random state if "rng" in extra_state_dict: # 'rng' may not exist for backward compatibility self.load_rng_state(extra_state_dict["rng"]) torch.distributed.barrier() @staticmethod def get_rng_state(): rng_state = { "cpu": torch.get_rng_state(), "cuda": torch.cuda.get_rng_state(), "numpy": np.random.get_state(), "random": random.getstate(), } return rng_state @staticmethod def load_rng_state(rng_state): torch.set_rng_state(rng_state["cpu"]) torch.cuda.set_rng_state(rng_state["cuda"]) np.random.set_state(rng_state["numpy"]) random.setstate(rng_state["random"])