chatlearn/models/megatron_module.py (193 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Megatron module""" import inspect import re import torch import torch.distributed as dist try: from chatlearn.utils.megatron_import_helper import get_args from chatlearn.utils.megatron_import_helper import mpu from chatlearn.utils.megatron_import_helper import initialize_megatron from chatlearn.utils.megatron_import_helper import save_checkpoint_and_time from chatlearn.utils.megatron_import_helper import set_jit_fusion_options from chatlearn.utils.megatron_utils import initialize_megatron as chatlearn_initialize_megatron from chatlearn.utils.megatron_utils import build_pipeline_layer_name_mapping from chatlearn.models.megatron.memory_manager import create_trainer_memory_manager, InferenceMemoryManager except ImportError: mpu = None from .torch_module import TorchModule # pylint: disable=import-outside-toplevel class MegatronModule(TorchModule): """MegatronModule is the class for Alignment Megatron models. Args ---- name : str model name """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if mpu is None: print("Cannot import megatron, please set megatron python path first.") if not self.trainable: # inference only if self.model_args.get("micro_batch_size") != self.module_args.generation_batch_size: self._logger.info(f"{self.name} Overwrite micro_batch_size with generation_batch_size {self.module_args.generation_batch_size}") self.model_args["micro_batch_size"] = self.module_args.generation_batch_size else: self.model_args["micro_batch_size"] = self.runtime_args.train_micro_batch_size self.model_args["global_batch_size"] = self.runtime_args.train_global_batch_size if self.model_args.get("micro_batch_size") != self.runtime_args.train_micro_batch_size: self._logger.info(f"{self.name} Overwrite micro_batch_size with train_micro_batch_size {self.module_args.train_micro_batch_size}") if self.model_args.get("global_batch_size") != self.runtime_args.train_global_batch_size: self._logger.info(f"{self.name} Overwrite global_batch_size with train_global_batch_size {self.module_args.train_global_batch_size}") if not self.model_args.get("tensorboard_dir") and self.runtime_args.output_dir is not None: self.model_args['tensorboard_dir'] = f"{self.runtime_args.output_dir}/tensorboard" def add_extra_args(self, parser): """ Add extra arguments for megatron. Args ---- parser : ArgumentParser Add extra arguments. """ return parser def init(self): """ :meta private: """ if "args_dict" in inspect.getfullargspec(initialize_megatron).args: initialize_func = initialize_megatron else: initialize_func = chatlearn_initialize_megatron initialize_func(extra_args_provider=self.add_extra_args, ignore_unknown_args=True, args_dict=self.model_args) if self.trainable: # slow down if set jit fusion for inference model set_jit_fusion_options() def model_setup(self): """ :meta private: """ super().model_setup() # TODO: we may need to let setup return model, optimizer and opt_param_scheduler if self.trainable: assert hasattr(self, "model") assert hasattr(self, "optimizer") assert hasattr(self, "opt_param_scheduler") if self.module_args.offload_weights or self.module_args.free_grad_buffers or self.module_args.offload_optimizer_states: self._memory_manager = create_trainer_memory_manager( self.megatron_model(), self.optimizer, self.megatron_args.use_distributed_optimizer, self.megatron_args.accumulate_allreduce_grads_in_fp32, self.megatron_args.params_dtype, self.runtime_args.bucket_size_mb_in_memory_manager, ) self.offload() else: assert hasattr(self, "model") self.model.eval() if self.module_args.offload_weights: self._memory_manager = InferenceMemoryManager( self.megatron_model(), self.runtime_args.bucket_size_mb_in_memory_manager, ) self.offload() self.set_pipe_layer_num_offset() def set_pipe_layer_num_offset(self): self.stage2layer_num = [None] * self.pipeline_model_parallel_size() self.stage2offset = [0] * self.pipeline_model_parallel_size() stage_layer_num = self.get_pipeline_stage_layer_num() world_size = torch.distributed.get_world_size() rank_layer_num = torch.tensor([self.pipeline_parallel_rank(), stage_layer_num], device='cuda') # Gather all tensors to all processes all_stage_layer_nums = [torch.zeros_like(rank_layer_num, device='cuda') for _ in range(world_size)] torch.distributed.all_gather(all_stage_layer_nums, rank_layer_num) for item in all_stage_layer_nums: rank = item[0].item() num = item[1].item() if self.stage2layer_num[rank] is None: self.stage2layer_num[rank] = num else: assert self.stage2layer_num[rank] == num for i, num in enumerate(self.stage2layer_num): if i+1 == len(self.stage2offset): break self.stage2offset[i+1] = self.stage2offset[i] + num @property def megatron_args(self): """ :meta private: """ return get_args() def pipeline_model_parallel_size(self): """ get pipeline_model_parallel_size :meta private: """ return self.megatron_args.pipeline_model_parallel_size def tensor_model_parallel_size(self): """ get tensor_model_parallel_size :meta private: """ return self.megatron_args.tensor_model_parallel_size def expert_model_parallel_size(self): """ get expert_model_parallel_size :meta private: """ if hasattr(self.megatron_args, "expert_model_parallel_size"): return self.megatron_args.expert_model_parallel_size if hasattr(self.megatron_args, "moe_expert_model_parallel_size"): return self.megatron_args.moe_expert_model_parallel_size return 1 def tensor_and_expert_model_parallel_size(self): """ get tensor_and_expert_model_parallel_size :meta private: """ return self.megatron_args.tensor_model_parallel_size * self.expert_model_parallel_size() @property def data_parallel_size(self): """ :meta private: """ return mpu.get_data_parallel_world_size() @property def data_parallel_rank(self): """ :meta private: """ return mpu.get_data_parallel_rank() def pipeline_parallel_rank(self): """ :meta private: """ return mpu.get_pipeline_model_parallel_rank() def tensor_parallel_rank(self): """ :meta private: """ return mpu.get_tensor_model_parallel_rank() def tensor_and_expert_parallel_group(self): """ :meta private: """ return mpu.get_tensor_and_expert_parallel_group() def expert_parallel_rank(self): """ :meta private: """ if hasattr(mpu, "get_expert_model_parallel_rank"): return mpu.get_expert_model_parallel_rank() return 0 def num_layers(self): """ :meta private: """ return self.megatron_args.num_layers def megatron_model(self): if isinstance(self.model, list): assert len(self.model) == 1 model = self.model[0] else: model = self.model return model def build_pipeline_layer_name_mapping(self, num_target_pipe_stage, target_pipe_rank, tgt_layer_offset, requires_grad=True): """ build name mapping from src model to tgt model Args: num_target_pipe_stage: number of pipeline stage in target model target_pipe_rank: target model pipeline rank tgt_layer_offset: target model pipeline stage layer offset requires_grad: whether the returned layer requires_grad, as we only need to sync parameters that have changed :meta private: """ src_layer_offset = self.get_pipeline_stage_layer_offset() model = self.megatron_model() is_tgt_last_stage = target_pipe_rank == num_target_pipe_stage - 1 and target_pipe_rank != 0 name_mapping = build_pipeline_layer_name_mapping(src_layer_offset, tgt_layer_offset, is_tgt_last_stage, model, requires_grad) return name_mapping def get_local_param_ranks(self): """ :meta private: """ if self.expert_model_parallel_size() == 1: data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS) return data_parallel_global_ranks, mpu.get_data_parallel_rank() else: # Get data parallel modulo expert parallel ranks data_modulo_expert_parallel_group = mpu.get_data_modulo_expert_parallel_group() data_modulo_expert_parallel_ranks = dist.get_process_group_ranks(data_modulo_expert_parallel_group) return data_modulo_expert_parallel_ranks, mpu.get_data_modulo_expert_parallel_rank() def save_checkpoint(self, iteration): """ save checkpoint at `iteration` :param iteration: save iteration :meta private: """ if self.enable_lora: self.fuse_lora_layer() save_checkpoint_and_time(iteration, self.model, self.optimizer, self.opt_param_scheduler) if self.enable_lora: self.unfuse_lora_layer() def offload_optimizer_states(self): """ offload optimizer states """ if self.module_args.offload_optimizer_states: self._memory_manager.offload_optimizer_states() def onload_optimizer_states(self): """ onload optimizer states """ if self.module_args.offload_optimizer_states: self._memory_manager.onload_optimizer_states() def offload_main_weights(self): """ offload main weights """ if self.module_args.offload_weights: self._memory_manager.offload_main_weights() def onload_main_weights(self): """ onload main weights """ if self.module_args.offload_weights: self._memory_manager.onload_main_weights() def offload_weights(self): """ offload weights """ if self.module_args.offload_weights: self._memory_manager.offload_weights() def onload_weights(self): """ onload weights """ if self.module_args.offload_weights: self._memory_manager.onload_weights() def free_grad_buffers(self): """ free grad buffers and related tensors """ if self.module_args.free_grad_buffers: self._memory_manager.free_grad_buffers() def build_grad_buffers(self): """ build grad buffers and related tensors """ if self.module_args.free_grad_buffers: self._memory_manager.build_grad_buffers() def get_pipeline_stage_layer_num(self): assert self.stage2layer_num is not None if self.stage2layer_num[self.pipeline_parallel_rank()] is not None: return self.stage2layer_num[self.pipeline_parallel_rank()] layer_re = re.compile(r'layers\.([0-9]+)') layer_set = set() for name in self.named_parameters: layer_num = re.findall(layer_re, name) if layer_num: layer_set.add(layer_num[0]) stage_layer_num = len(layer_set) return stage_layer_num def get_pipeline_stage_layer_offset(self): assert self.stage2offset is not None and \ self.stage2offset[self.pipeline_parallel_rank()] is not None return self.stage2offset[self.pipeline_parallel_rank()]