chatlearn/models/base_module.py (754 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""base module"""

from collections import defaultdict
from itertools import cycle
from pathlib import Path
import math
import time
import os

import torch
import ray
import ray.util.collective as col
from ray.util.collective.collective_group.base_collective_group import BaseGroup
from ray.util.collective.collective_group.nccl_collective_group import NCCLGroup
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from chatlearn.data.sampler import MultiDatasetSampler
from chatlearn.data.data import RLHFDataLoader
from chatlearn.checkpoint.checkpoint_manager import CheckpointManager
from chatlearn.utils import future
from chatlearn.utils.constant import LOG_START
from chatlearn.utils.dist_utils import bucket_tensors, coalesced_comm_dense
from chatlearn.utils.dist_utils import bucket_tensors_two_stage_generator, coalesced_comm_dense_two_stage
from chatlearn.utils.global_vars import get_args
from chatlearn.utils.global_vars import set_global_variables
from chatlearn.utils.logger import logger
from chatlearn.utils.logger import log_rank_0, debug_rank_0, setup_logger
from chatlearn.utils.timer import Timers
from chatlearn.utils.utils import get_host_addr, map_reduce_metrics
from chatlearn.launcher import dlc_utils


class BaseModule:
    """BaseModule is the base class for Base models.

    Several members used below (``self.model``, ``tensor_parallel_rank()``,
    ``tensor_and_expert_parallel_group()``) are not defined in this class and
    are expected to be provided by subclasses.

    Args
    ----
    name : str
        model name
    args :
        global arguments; when ``None`` they are fetched with ``get_args()``,
        otherwise they are also installed via ``set_global_variables``.
    replica_id : int
        index of this replica (default: ``0``)
    """

    def __init__(self, name, args=None, replica_id=0):
        logger.info(f"{LOG_START} basemodule {name} init start")
        self.name = name
        if args is None:
            global_args = get_args()
        else:
            global_args = args
            set_global_variables(args)
        self.global_args = global_args
        # From here on `args` refers to the per-model section of the config.
        args = global_args.models[name]
        self.total_gpu = args.num_gpu
        self.total_cpu = args.num_cpu
        self.gpu_per_process = args.gpu_per_process
        self.trainable = args.trainable
        self._runtime_args = self.global_args.runtime_args
        self._module_args = args
        self.replica_id = replica_id
        self.config_dir = args.config_dir
        self._is_colocate = False

        if self.total_gpu > 0:
            self._num_gpu_per_replica = (
                args.tensor_model_parallel_size
                * args.pipeline_model_parallel_size
                * args.expert_model_parallel_size
                * args.zero_size
                * args.fsdp_size
            )
            # The message lists every factor of the product above (fsdp_size included).
            assert self._num_gpu_per_replica <= self.total_gpu, \
                f"_num_gpu_per_replica {self._num_gpu_per_replica} larger than total_gpu {self.total_gpu} " + \
                f"tp_size: {args.tensor_model_parallel_size} pp_size: {args.pipeline_model_parallel_size} " + \
                f"ep_size: {args.expert_model_parallel_size} zero_size: {args.zero_size} " + \
                f"fsdp_size: {args.fsdp_size}"
            assert self.total_gpu % self._num_gpu_per_replica == 0
            if not self.trainable:
                self._num_replica = args.num_gpu // self._num_gpu_per_replica
            else:
                # For trainable models, perform the DP inside DistActor
                self._num_replica = 1
                self._num_gpu_per_replica = self.total_gpu
        else:
            self._num_gpu_per_replica = 0
            self._num_replica = args.num_replica

        assert self._num_replica >= 1
        self._param_ranks = None
        self._named_parameters = None
        self._param_to_name = None
        self._parameters = None
        self._coalesced_parameters = None
        self.error_signal = None
        self._rank = None
        self._world_size = None
        self._group_names = []
        self._dataloader = None
        self._eval_dataloader = None
        self._kl_coef = None
        self._padding_config = {}
        self._storage = None
        self._timers = None
        self._data_iter = None
        self._eval_data_iter = None
        self.call_funcs = []
        self.trainable_funcs = []
        self._data_ckpt_manager = None
        self._peak_memory = 0
        # Each maps pipe_stage -> list of (name, parameter); `_parameters_to_recv`
        # is keyed by to_rank first (see `set_recv_parameters`).
        self._parameters_to_sync = defaultdict(list)
        self._parameters_to_send = defaultdict(list)
        self._parameters_to_recv = defaultdict(list)
        self._parameters_shape = []
        # current compute iteration
        self._iteration = 0
        self._train_iteration = 0
        self._episode_id = 0
        self.enable_lora = self._module_args.lora.enable_lora
        self._finalized = False
        self._resume_training = False
        self._address = dlc_utils.get_addr() if dlc_utils.in_dlc_env() else get_host_addr()
        self._is_master_node = os.environ.get("RANK", '0') == '0'
        self._logger = setup_logger(model_name=self.name, ip_addr=self._address)
        # parameter sync from src_model
        self._src_parameter_model = None
        self.profiler = None
        self._buffer_num = {}
        self._tp_division = {}
        self._tp_num_mapping = 1
        self._sync_buffer = defaultdict(list)
        self._sync_dst_rank_to_src_ranks = {}
        self._expert_sync_buffer = {}
        self._synchronizer = None
        self._metric_prefix = ""
        self._metric_list = []
        self._stage_resume_done = False
        logger.info(f"{LOG_START} basemodule {name} init done")

    def get_sync_buffer(self):
        """Return the sync buffer filled by `broadcast_parameter_two_stage`."""
        return self._sync_buffer

    def set_tp_num_mapping(self, _tp_num_mapping):
        """Set the value returned by `tp_num_mapping`."""
        self._tp_num_mapping = _tp_num_mapping

    @property
    def tp_num_mapping(self):
        """Modulus used to index the sync buffer in two-stage broadcast (default 1)."""
        return self._tp_num_mapping

    def set_buffer_num(self, buffer_num):
        """Merge a ``{param_name: buffer_num}`` mapping into the stored one."""
        self._buffer_num.update(buffer_num)

    def get_buffer_num(self, param_names):
        """Return the stored buffer number for each name; raises KeyError if one is missing."""
        return [self._buffer_num[name] for name in param_names]

    def set_tp_division(self, tp_division):
        """Merge a ``{param_name: tp_division}`` mapping into the stored one."""
        self._tp_division.update(tp_division)

    def get_tp_division(self, param_names):
        """Return the stored tp division for each name; raises KeyError if one is missing."""
        return [self._tp_division[name] for name in param_names]

    @property
    def is_colocate(self):
        """Colocation flag set through `set_colocate` (default ``False``)."""
        return self._is_colocate

    def set_colocate(self, flag):
        """Set the colocation flag."""
        self._is_colocate = flag

    def finalize(self):
        """
        finalize the class, any change from user after finalize will not work.

        :meta private:
        """
        self._finalized = True

    def _assert_not_finalized(self):
        """
        :meta private:
        """
        assert not self._finalized, f"{self} is finalized, any change to the class should happen before finalize."

    def get_runtime_args(self):
        """Method form of the `runtime_args` property."""
        return self.runtime_args

    @property
    def runtime_args(self):
        """
        Return the arguments related to alignment training,
        the settings that are specified under the "runtime" section of the YAML configuration file.
        """
        return self._runtime_args

    @property
    def model_args(self):
        """
        Return model arguments, such as those related to Megatron,
        should be specified in a separate configuration yaml file for the model being used.
        """
        return self._module_args.args_dict

    @property
    def module_args(self):
        """
        Return module arguments. module_args include `num_gpu`, `gpu_per_process`, `model_config_file`, etc.
        """
        return self._module_args

    @property
    def parameter_sync_frequency(self):
        """Return ``module_args.sync_frequency``."""
        return self.module_args.sync_frequency

    def set_env(self, args):
        """
        set system env, private

        :meta private:
        """

    def set_error_signal(self, error_signal):
        """
        signal for handling errors

        :meta private:
        """
        self.error_signal = error_signal

    def error(self, error_msg=None):
        """
        Report `error_msg` through the error signal and wait for the remote call.

        :meta private:
        """
        future.wait(self.error_signal.set.remote(error_msg))

    def init(self):
        """
        Init env.
        """

    def setup(self):
        """
        Create model / optimizer / opt_param_scheduler / etc.
        """

    @property
    def data_ckpt_manager(self):
        """
        Return the data checkpoint manager, or ``None`` when no
        ``data_checkpoint_path`` is configured. Asserts that the manager has
        been created (by `model_setup`) when the path is configured.

        :meta private:
        """
        if self.runtime_args.data_checkpoint_path is not None:
            assert self._data_ckpt_manager is not None
        return self._data_ckpt_manager

    def model_setup(self):
        """
        Create the data checkpoint manager (when configured), restore the
        episode / iteration counters when resuming, then call `setup`.

        :meta private:
        """
        self.global_args.active_module_args = self._module_args
        if self.runtime_args.data_checkpoint_path is not None:
            self._data_ckpt_manager = CheckpointManager(self, self.runtime_args.data_checkpoint_path,
                                                        self.runtime_args.max_data_ckpt_nums,
                                                        self.runtime_args.load_data_checkpoint_iteration)
            if self.runtime_args.enable_resume_training:
                meta = self._data_ckpt_manager.resume()
                if meta:
                    self._resume_training = self.runtime_args.consumed_samples > 0
                    # Restart from the episode after the last checkpointed one.
                    start_episode = meta["episode"] + 1
                    self._episode_id = start_episode
                    self._iteration = start_episode * math.ceil(self.runtime_args.sample_per_episode / \
                        self._num_replica / self.module_args.generation_batch_size)
                    log_rank_0(
                        f"{self.name} resume training {self._resume_training}: "
                        f"set start iteration to {self._iteration} and episode id to {self._episode_id}",
                        self._logger)
        self.setup()

    def forward_step(self, data, iteration):
        """
        Perform forward step for one batch.

        Args
        ----
        data : dict
            data for forward_step
        iteration : int
            local forward iteration

        Returns
        -------
        Dict
            A dict of results, where key is the string type, and the value is the tensor or a list,
            where the first dim of tensor or the len of list equals to batch size
        """

    def train_step(self, data, iteration):
        """
        Perform train_step for one batch, including a list of micro-batches.

        Args
        ----
        data : [Dict]
            A list of micro-batch for train_step, type of each micro-batch is dict
        iteration : int
            local train iteration
        """

    def eval_step(self, data):
        """
        Perform eval_step for one batch

        Args
        ----
        data: Dict
            Data for eval_step.

        Returns
        -------
        Dict
            A dict of results, where key is the string type, and the value is the tensor or a list,
            where the first dim of tensor or the len of list equals to batch size
        """

    def save_checkpoint(self, iteration):
        """
        Save checkpoint given iteration.

        Args
        ----
        iteration: int
            Current training iteration
        """

    def save_data_checkpoint(self, replica_id, iteration, episode_id):
        """
        Save checkpoint for dataloader. Does nothing when there is no data checkpoint manager.

        :meta private:
        """
        if self.data_ckpt_manager is not None:
            consumed_samples = self.runtime_args.consumed_samples
            self.data_ckpt_manager.save_checkpoint(replica_id, iteration, episode_id, consumed_samples)

    def put(self, key, data):
        """
        Put the data to shared storage. The remote call is not waited on.

        Args
        ----
        key: Str
            Use key to put.
        data
            data to save
        """
        self._storage.put.remote(key, data)

    def get(self, key):
        """
        Get data from shared storage using key

        Args
        ----
        key: Str
            use key to get
        """
        ref = self._storage.get.remote(key)
        return future.get(ref)

    def validate(self):
        """
        :meta private:
        """

    def before_episode(self):
        """
        Operations before one episode.
        """

    def after_episode(self):
        """
        Operations after one episode.
        """
        self._episode_id += 1

    def build_dataset(self, train_prompts, is_eval=False):
        """
        Build prompt dataset

        Args
        ----
        train_prompts: [Str]
            A list of prompt string.
        is_eval: bool
            whether the dataset is built for evaluation (default: `False`)

        Returns
        -------
        torch.utils.data.Dataset
            Dataset with user-defined collate_fn
        """

    def build_all_dataset(self, train_prompts_list, is_eval=False):
        """
        Build all prompt datasets

        Args
        ----
        train_prompts_list: List[List[Str]]
            A list of prompt string lists.
        is_eval: bool
            forwarded to `build_dataset` (default: `False`)

        Returns
        -------
        List[torch.utils.data.Dataset]
            A list of Dataset with user-defined collate_fn
        """
        return [self.build_dataset(train_prompts, is_eval) for train_prompts in train_prompts_list]

    def _build_dataloader(self, data, sample_per_episode, is_eval=False):
        """
        build and set the dataloader for the model

        Args:
            data: a list of prompt lists, forwarded to `build_all_dataset`
            sample_per_episode: forwarded to `build_dataloader`
            is_eval: set to `True` to build a dataloader for evaluation (default: `False`)

        :meta private:
        """
        all_datasets = self.build_all_dataset(data, is_eval) # pylint: disable=assignment-from-no-return
        consumed_samples = 0
        data_ratio = self.runtime_args.data_ratio
        shuffle = self.runtime_args.data_shuffle
        data_rerank = self.runtime_args.data_rerank
        # Only the training dataloader continues from previously consumed samples.
        if not is_eval and self.data_ckpt_manager is not None:
            consumed_samples = self.runtime_args.consumed_samples
        collate_fn = getattr(all_datasets[0], 'collate_fn', None)
        drop_last = self.model_args.get('drop_last', False)
        dataloader = self.build_dataloader(all_datasets,
                                           sample_per_episode=sample_per_episode,
                                           collate_fn=collate_fn,
                                           is_eval=is_eval,
                                           consumed_samples=consumed_samples,
                                           data_ratio=data_ratio,
                                           shuffle=shuffle,
                                           drop_last=drop_last,
                                           data_rerank=data_rerank)

        if is_eval:
            self._eval_dataloader = dataloader
            self._eval_data_iter = iter(self._eval_dataloader)
        else:
            # NOTE(review): itertools.cycle stores every item produced during the
            # first pass and replays the stored items once the iterator is
            # exhausted -- confirm this is intended if the dataloader is finite.
            self._data_iter = cycle(iter(dataloader))
            self._dataloader = dataloader

    def build_dataloader(self,
                         all_datasets,
                         sample_per_episode,
                         collate_fn=None,
                         is_eval=False,
                         consumed_samples=0,
                         data_ratio=None,
                         shuffle=True,
                         drop_last=False,
                         data_rerank=True):
        """
        build the dataloader for the model

        Args:
            all_datasets: a list of torch.utils.data.Dataset objects
            sample_per_episode: passed to the batch sampler as `sample_per_episode`
            collate_fn: set when loading from an map-style dataset (default: `None`)
            is_eval: set to `True` to build a dataloader for evaluation (default: `False`)
            consumed_samples: consumed samples (default: `0`)
            data_ratio: ratio of samples for each dataset (default: `None`)
            shuffle: passed to the training sampler; the eval sampler never shuffles (default: `True`)
            drop_last: whether to drop last samples (default: `False`)
            data_rerank: passed to the training sampler as `data_rerank` (default: `True`)

        :meta private:
        """
        log_rank_0(
            f"Creating DataLoader... consumed_samples: {consumed_samples}, "
            f"data_ratio: {data_ratio}",
            self._logger
        )
        num_inference_per_prompt = self.model_args.get("num_inference_per_prompt", 1)
        vllm_prompt_key = self.model_args.get("vllm_prompt_key", "prompt")
        self._logger.info(f"====Data Rerank: {data_rerank}")
        dataset_sizes = [len(dataset) for dataset in all_datasets]
        if is_eval:
            batch_sampler = MultiDatasetSampler(
                dataset_sizes=dataset_sizes,
                sample_per_episode=sample_per_episode,
                shuffle=False,
                is_eval=True,
                data_parallel_rank=self.replica_id,
                data_parallel_size=self._num_replica
            )
        else:
            batch_sampler = MultiDatasetSampler(
                dataset_sizes=dataset_sizes,
                sample_per_episode=sample_per_episode,
                data_ratio=data_ratio,
                consumed_samples=consumed_samples,
                num_inference_per_prompt=num_inference_per_prompt,
                shuffle=shuffle,
                is_eval=False,
                data_parallel_rank=self.replica_id,
                data_parallel_size=self._num_replica,
                drop_last="drop" if drop_last else "cycle",
                data_rerank=data_rerank
            )
        return RLHFDataLoader(
            all_datasets,
            batch_sampler,
            collate_fn=collate_fn,
            add_uid=True,
            data_parallel_rank=self.replica_id,
            data_parallel_size=self._num_replica,
            num_inference_per_prompt=num_inference_per_prompt,
            vllm_prompt_key=vllm_prompt_key
        )

    def reset_eval_data_iter(self):
        """
        Restart iteration over the eval dataloader, if one has been built.

        :meta private:
        """
        if self._eval_dataloader is not None:
            self._eval_data_iter = iter(self._eval_dataloader)

    def next_batch(self, is_eval=False):
        """
        Return the next batch from the eval or the training data iterator.

        :meta private:
        """
        if is_eval:
            return next(self._eval_data_iter)
        return next(self._data_iter)

    @property
    def num_replica(self):
        """
        :meta private:
        """
        return self._num_replica

    @property
    def num_gpu_per_replica(self):
        """
        :meta private:
        """
        return self._num_gpu_per_replica

    def setup_collective_group(self, rank, world_size, backend, group_name):
        """
        Record `group_name` and `world_size`, then initialise the collective group.

        :meta private:
        """
        self._group_names.append(group_name)
        self._world_size = world_size
        col.init_collective_group(
            world_size, rank, backend=backend, group_name=group_name)

    def broadcast_dummy_tensor_send(self, src_rank, group_name):
        """Broadcast a one-element zero CUDA tensor over `group_name`."""
        x = torch.zeros(1, device="cuda")
        col.broadcast(x, src_rank=src_rank, group_name=group_name)
        del x

    def broadcast_dummy_tensor_recv(self, src_rank, group_name):
        """Same call sequence as `broadcast_dummy_tensor_send`."""
        x = torch.zeros(1, device="cuda")
        col.broadcast(x, src_rank=src_rank, group_name=group_name)
        del x

    def _destroy_collective_group(self, group_name):
        """
        Destroy one collective group and, for NCCL groups, kill the named
        store actors that are still alive afterwards.

        :meta private:
        """
        from ray.util.collective.collective import _group_mgr # pylint: disable=import-outside-toplevel
        rank = col.get_rank(group_name)
        saved_group: BaseGroup = _group_mgr.get_group_by_name(group_name)
        # Capture the communicator keys before the group is destroyed.
        saved_comm_keys = []
        if isinstance(saved_group, (NCCLGroup, )):
            saved_comm_keys = list(saved_group._dev_comm_map.keys())

        try:
            col.destroy_collective_group(group_name)
        except Exception as e:
            # Best effort: log and continue with the store cleanup below.
            self._logger.warning(f"_destroy_collective_group {group_name} {e}")

        if isinstance(saved_group, (NCCLGroup, )):
            from ray.util.collective.const import get_store_name # pylint: disable=import-outside-toplevel
            for comm_key in saved_comm_keys:
                group_key = saved_group._generate_group_key(comm_key)
                store_name = get_store_name(group_key)

                try:
                    store = ray.get_actor(store_name)
                    if rank == 0:
                        raise RuntimeError(f'{store_name} in group {group_name} should be killed on rank {rank}.')
                    self._logger.debug(f'Kill {store_name} in group {group_name} on rank {rank}')
                    ray.kill(store)
                except ValueError:
                    # ValueError from the lookup above is ignored: nothing to kill.
                    ...

    def destroy_collective_group(self):
        """Destroy every group recorded by `setup_collective_group` and clear the record."""
        for group_name in self._group_names:
            self._destroy_collective_group(group_name)
        self._group_names = []

    def get_local_param_ranks(self):
        """
        :meta private:
        """

    def fuse_lora_layer(self):
        """
        :meta private:
        """
        from chatlearn.models.megatron.lora import fuse_lora_layer # pylint: disable=import-outside-toplevel
        fuse_lora_layer(self.model)

    def unfuse_lora_layer(self):
        """
        :meta private:
        """
        from chatlearn.models.megatron.lora import unfuse_lora_layer # pylint: disable=import-outside-toplevel
        unfuse_lora_layer(self.model)

    @property
    def rank(self):
        """
        :meta private:
        """
        return self._rank

    def get_rank(self):
        """
        :meta private:
        """
        return self.rank

    def is_last_rank(self):
        """
        Is last rank.
        """
        return True

    def _model_partitions(self):
        """Return `self.model` as a list: unchanged if it already is one, else wrapped."""
        model = self.model
        return model if isinstance(model, list) else [model]

    @property
    def parameters(self):
        """
        Flat list of the parameters of every model partition (computed once, then cached).

        :meta private:
        """
        if self._parameters is None:
            self._parameters = [
                param
                for partition in self._model_partitions()
                for param in partition.parameters()
            ]
        return self._parameters

    @property
    def named_parameters(self):
        """
        Mapping name -> parameter over every model partition (computed once, then cached).

        :meta private:
        """
        if self._named_parameters is None:
            self._named_parameters = {
                name: param
                for partition in self._model_partitions()
                for name, param in partition.named_parameters()
            }
        return self._named_parameters

    @property
    def param_to_name(self):
        """
        Mapping parameter -> name over every model partition (computed once, then cached).

        :meta private:
        """
        if self._param_to_name is None:
            self._param_to_name = {
                param: name
                for partition in self._model_partitions()
                for name, param in partition.named_parameters()
            }
        return self._param_to_name

    def _set_sync_parameters(self, trainable_param_names, pipe_stage=0, parameters_to_sync=None):
        """
        Store the ``(name, parameter)`` list for `pipe_stage` in `parameters_to_sync`
        (a fresh mapping when ``None``) and return that mapping. The list is passed
        through the synchronizer's `transform_parameters` when a synchronizer is set.
        The slot for `pipe_stage` must be absent or empty.
        """
        if parameters_to_sync is None:
            parameters_to_sync = defaultdict(list)
        assert pipe_stage not in parameters_to_sync or len(parameters_to_sync[pipe_stage])==0
        params_to_sync_list = [(name, self.named_parameters[name]) for name in trainable_param_names]
        if self._synchronizer is not None:
            params_to_sync_list = self._synchronizer.transform_parameters(params_to_sync_list)
        parameters_to_sync[pipe_stage] = params_to_sync_list
        return parameters_to_sync

    def set_sync_parameters(self, trainable_param_names, pipe_stage=0, parameters_to_sync=None):
        """
        Fill the `pipe_stage` slot of `parameters_to_sync` (default: this module's
        own sync mapping) unless it is already populated.

        :meta private:
        """
        if parameters_to_sync is None:
            parameters_to_sync = self._parameters_to_sync
        if pipe_stage not in parameters_to_sync or len(parameters_to_sync[pipe_stage]) == 0:
            self._set_sync_parameters(trainable_param_names, pipe_stage, parameters_to_sync)

    def reset_sync_parameters(self, trainable_param_names, pipe_stage=0):
        """Discard the `pipe_stage` slot of the sync mapping and rebuild it."""
        self._parameters_to_sync[pipe_stage] = []
        self._set_sync_parameters(trainable_param_names, pipe_stage, self._parameters_to_sync)

    def set_send_parameters(self, trainable_param_names, pipe_stage=0):
        """
        :meta private:
        """
        return self.set_sync_parameters(trainable_param_names, pipe_stage, self._parameters_to_send)

    def set_recv_parameters(self, to_rank, trainable_param_names, pipe_stage=0):
        """
        Replace the receive mapping of `to_rank` with a fresh one and fill its `pipe_stage` slot.

        :meta private:
        """
        parameters_to_recv = defaultdict(list)
        self._parameters_to_recv[to_rank] = parameters_to_recv
        return self.set_sync_parameters(trainable_param_names, pipe_stage, parameters_to_recv)

    def clear_sync_parameters(self):
        """Reset the sync mapping."""
        self._parameters_to_sync = defaultdict(list)

    def clear_send_recv_parameters(self):
        """Reset the send and receive mappings."""
        self._parameters_to_send = defaultdict(list)
        self._parameters_to_recv = defaultdict(list)

    def clear_sync_send_recv_parameters(self):
        """Reset the sync, send and receive mappings."""
        self.clear_sync_parameters()
        self.clear_send_recv_parameters()

    def get_parameter_names(self, requires_grad=True):
        """
        Return parameter names; only those with ``requires_grad`` set when
        `requires_grad` is true, otherwise all of them.

        :meta private:
        """
        param_to_name = self.param_to_name
        if requires_grad:
            return [param_to_name[param] for param in self.parameters if param.requires_grad]
        return [param_to_name[param] for param in self.parameters]

    def get_parameter_shape(self, pipe_stage=0, parameters_to_sync=None):
        """
        Return ``(name, shape)`` for every entry of the `pipe_stage` slot. The
        expert sync buffer's shape is used for a name present in it when the
        synchronizer reports `is_parameter_changed`.

        :meta private:
        """
        if parameters_to_sync is None:
            parameters_to_sync = self._parameters_to_sync
        parameters_shape = []
        for name, param in parameters_to_sync[pipe_stage]:
            if self._expert_sync_buffer and name in self._expert_sync_buffer and \
                    self._synchronizer and self._synchronizer.is_parameter_changed:
                parameters_shape.append((name, self._expert_sync_buffer[name].shape))
            else:
                parameters_shape.append((name, param.shape))
        return parameters_shape

    def get_parameter(self, name):
        """
        Return the parameter called `name`; raises ``Exception`` when it is unknown.

        :meta private:
        """
        if name not in self.named_parameters:
            raise Exception(f"parameter {name} does not exist")
        return self.named_parameters[name]

    def get_parameter_to_sync(self, name, pipe_stage, to_cpu=False, regroup=False):
        """
        Return the entry called `name` from the `pipe_stage` slot of the sync
        mapping, moved to CPU (`to_cpu`) or CUDA. With `regroup` and a synchronizer
        set, the data is first passed through `regroup_params_to_sync`.
        Returns ``None`` when no entry has that name.
        """
        assert pipe_stage in self._parameters_to_sync and len(self._parameters_to_sync[pipe_stage]) > 0
        for name0, param in self._parameters_to_sync[pipe_stage]:
            if name0 != name:
                continue
            if name in self._expert_sync_buffer and self._synchronizer and \
                    self._synchronizer.is_parameter_changed:
                param = self._expert_sync_buffer[name]
                regroup_routed_experts = True
            else:
                regroup_routed_experts = False
            if regroup and self._synchronizer:
                param = self._synchronizer.regroup_params_to_sync(
                    name,
                    param.data,
                    self._tp_division[name],
                    regroup_routed_experts
                )
            if to_cpu:
                param = param.cpu()
            else:
                param = param.cuda()
            return param
        return None

    def get_parameter_to_sync_names(self, pipe_stage):
        """Return the names stored in the `pipe_stage` slot of the sync mapping."""
        return [items[0] for items in self._parameters_to_sync[pipe_stage]]

    def exist_parameter(self, name):
        """
        :meta private:
        """
        return name in self.named_parameters

    def parameter_shape(self, name):
        """
        :meta private:
        """
        return self.get_parameter(name).shape

    def send_recv_parameter(self, rank, group_name, func, pipe_stage=0):
        """
        Apply `func` (``col.send`` or ``col.recv``, see `send_parameter` /
        `recv_parameter`) to the bucketed parameters of `pipe_stage`.

        :meta private:
        """
        tensors = [param.data for _, param in self._parameters_to_sync[pipe_stage]]
        dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
        debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger)
        # Loop-invariant: tensors are only flagged as changed on the receiving side.
        tensor_changed = func is col.recv
        for bucket in dense_buckets:
            coalesced_comm_dense(bucket, func, extra_args=(rank, group_name), tensor_changed=tensor_changed)

        for param in sparse_bucket:
            func(param, rank, group_name)

    def alltoall_routed_expert_parameter(self, pipe_stage=0):
        """Run the synchronizer's `alltoall_routed_experts` per entry; keep results whose state is truthy."""
        assert self._synchronizer is not None
        for name, param in self._parameters_to_sync[pipe_stage]:
            param, state = self._synchronizer.alltoall_routed_experts(
                name,
                param,
                self.tensor_and_expert_parallel_group()
            )
            if state:
                self._expert_sync_buffer.pop(name, "Not Found.")
                self._expert_sync_buffer[name] = param

    def allgather_routed_expert_parameter(self, group_name, pipe_stage=0):
        """Run the synchronizer's `allgather_routed_experts` per entry; keep results whose state is truthy."""
        assert self._synchronizer is not None
        for name, param in self._parameters_to_sync[pipe_stage]:
            param, state = self._synchronizer.allgather_routed_experts(
                name,
                param,
                group_name,
                tp_rank=self.tensor_parallel_rank()
            )
            if state:
                self._expert_sync_buffer.pop(name, "Not Found.")
                self._expert_sync_buffer[name] = param

    def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
        """
        Broadcast the bucketed tensors of `pipe_stage` from `src_rank` over
        `group_name`. The expert sync buffer entry replaces a parameter's data
        when the synchronizer reports `is_parameter_changed`.

        :meta private:
        """
        tensors = []
        for name, param in self._parameters_to_sync[pipe_stage]:
            if self._expert_sync_buffer and name in self._expert_sync_buffer and \
                    (self._synchronizer and self._synchronizer.is_parameter_changed):
                tensors.append(self._expert_sync_buffer[name])
            else:
                tensors.append(param.data)

        assert len(tensors) > 0
        dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
        debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger)
        tensor_changed = rank != src_rank

        for bucket in dense_buckets:
            coalesced_comm_dense(bucket, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed)

        for param in sparse_bucket:
            col.broadcast(param, src_rank, group_name)

    def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
        """
        Arguments:
            to_rank: receive rank in mapping from trainer to inference model.
            buffer_rank: index which tensors of sync buffer to be sent in stage2.
            rank: destination rank in communication group which enumerate receive ranks.
            src_rank: source rank in communication group. always 0.
            group_name: communication group name.
            pipe_stage: pipeline stage. default 0.
            stage2: bool. whether stage2 or not. default False.
        Example: trainer_tp = 4, inference_tp = 8. pipeline_size = 1
            stage1: [(from_rank, to_rank), ...] = [(0, 8), (1, 10), (2, 12), (3, 14)]
            stage2: [(from_rank, to_rank), ...] = [(8, 9), (10, 11), (12, 13), (14, 15)]

            For stage1 pair (0, 8):
                1. call broadcast func: (0 -> 0). src_rank: 0, rank: 0.
                2. call broadcast func: (0 -> 8). src_rank: 0, rank: 1.

                After (0, 8), to_rank 8 received tensor slices of 8 and 9.

            For stage2 pair (8, 9):
                1. call broadcast func: (8 -> 8). src_rank: 0, rank: 0.
                2. call broadcast func: (8 -> 9). src_rank: 0, rank: 1.
                In (8 -> 8), we need to send tp_slice of 'to_rank' 9, so set buffer_rank 9 to fetch tensors in sync buffer.
        """
        tensor_changed = rank != src_rank
        start = time.time()
        arguments = f"{to_rank}_{buffer_rank}_{rank}_{src_rank}_{group_name}_{pipe_stage}_{stage2}"

        if stage2:
            if tensor_changed:
                parameters_to_sync = self._parameters_to_recv[to_rank]
            else:
                parameters_to_sync = self._parameters_to_send
        else:
            if rank not in self._sync_dst_rank_to_src_ranks:
                # First stage1 call seen for this rank: start from an empty sync buffer.
                self._sync_dst_rank_to_src_ranks.update({rank:[src_rank]})
                del self._sync_buffer
                self._sync_buffer = defaultdict(list)
            else:
                self._sync_dst_rank_to_src_ranks[rank].append(src_rank)
            parameters_to_sync = self._parameters_to_sync

        def tensor_generator():
            if stage2 and not tensor_changed and self._sync_buffer:# pylint: disable=too-many-nested-blocks
                # Stage2 sender: replay the tensors buffered on CPU during stage1.
                for idx, (name, _) in enumerate(parameters_to_sync[pipe_stage]):
                    value = self._sync_buffer[buffer_rank % self.tp_num_mapping][idx].cuda() # restore from cpu
                    self._logger.debug(
                        f"Adding {name}({value.shape}) to sync for if branch from "
                        f"src_rank: {src_rank} to rank: {rank} in pipe_stage {pipe_stage}"
                    )
                    buffer_num = 1
                    yield value, buffer_num
                del self._sync_buffer[buffer_rank % self.tp_num_mapping]
            else:
                for name, param in parameters_to_sync[pipe_stage]:
                    param_data = param.data
                    if rank and self._buffer_num and not stage2:
                        assert name in self._buffer_num, f"{name} in self._buffer_num for rank {rank}"
                        buffer_num = self._buffer_num[name]
                    elif stage2:
                        buffer_num = 1
                    else:
                        if self._expert_sync_buffer and name in self._expert_sync_buffer:
                            param_data = self._expert_sync_buffer[name]
                            regroup_routed_experts = True # For routed experts in Qwen2vLLM
                        else:
                            regroup_routed_experts = False
                        # regroup src_tensor by tp_rank
                        param_data = self._synchronizer.regroup_params_to_sync(
                            name,
                            param_data,
                            self._tp_division[name],
                            regroup_routed_experts
                        )
                        # move self._expert_sync_buffer[name] to cpu mem to save gpu mem
                        if regroup_routed_experts and name in self._expert_sync_buffer:
                            cpu_expert = self._expert_sync_buffer[name].cpu()
                            del self._expert_sync_buffer[name]
                            self._expert_sync_buffer[name] = cpu_expert
                        buffer_num = 1
                    self._logger.debug(
                        f"Adding {name}({param_data.shape}) to sync for else branch from "
                        f"src_rank: {src_rank} to rank: {rank} in pipe_stage {pipe_stage}"
                    )
                    yield param_data, buffer_num

        bucket_generator = bucket_tensors_two_stage_generator(
            tensor_generator, bucket_size_mb=self.runtime_args.coalesced_buffer_mb,
            stage2=stage2, tensor_changed=tensor_changed and not stage2
        )
        dense_bucket_num = 0
        sparse_bucket_num = 0
        for bucket_or_tensor, is_dense in bucket_generator:
            if is_dense:
                index = 0 if stage2 else (to_rank % self.tp_num_mapping)
                all_buffers = coalesced_comm_dense_two_stage(
                    bucket_or_tensor, col.broadcast, rank,
                    extra_args=(src_rank, group_name), tensor_changed=tensor_changed,
                    stage2=stage2, index=index)
                if tensor_changed and not stage2:
                    # Stage1 receiver: keep the received buffers on CPU for stage2.
                    for key, value in all_buffers.items():
                        cpu_value = []
                        for tensor in value:
                            cpu_value.append(tensor.cpu().pin_memory()) # save gpu memory
                        del value
                        self._sync_buffer[key] += cpu_value
                del all_buffers
                dense_bucket_num += 1
            else:
                col.broadcast(bucket_or_tensor, src_rank, group_name)
                sparse_bucket_num += 1

        if stage2:
            self._sync_dst_rank_to_src_ranks = {}

        self._logger.debug(f"broadcast_parameter_two_stage {arguments} done using {time.time()-start} seconds")
        debug_rank_0(f"{self.name} Got dense_buckets {dense_bucket_num}, sparse_bucket {sparse_bucket_num}", self._logger)

    def send_parameter(self, dst_rank, group_name, pipe_stage=0):
        """
        :meta private:
        """
        self.send_recv_parameter(dst_rank, group_name, col.send, pipe_stage)

    def recv_parameter(self, src_rank, group_name, pipe_stage=0):
        """
        :meta private:
        """
        self.send_recv_parameter(src_rank, group_name, col.recv, pipe_stage)

    def ray_put_parameter(self, group_name, pipe_stage=0):
        """
        Put the bucketed parameters of `pipe_stage` into the Ray object store and
        return a mapping from ``"<group_name>:dense_bucket_<i>"`` /
        ``"<group_name>:sparse_bucket_<i>"`` to the object refs.

        :meta private:
        """
        name2ref = {}
        tensors = [param.data for _, param in self._parameters_to_sync[pipe_stage]]
        dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
        debug_rank_0(f"{self.name} Put dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger)
        for bucket_id, bucket in enumerate(dense_buckets):
            flat_tensors = _flatten_dense_tensors(bucket)
            flat_tensors_ref = ray.put(flat_tensors)
            name2ref[group_name + ":dense_bucket_" + str(bucket_id)] = flat_tensors_ref
        for param_id, param in enumerate(sparse_bucket):
            param_ref = ray.put(param)
            name2ref[group_name + ":sparse_bucket_" + str(param_id)] = param_ref
        return name2ref

    def ray_get_parameter(self, group_name, name2ref, pipe_stage=0):
        """
        Copy the tensors referenced by `name2ref` (as produced by
        `ray_put_parameter`) into the parameters of `pipe_stage`.

        :meta private:
        """
        tensors = [param.data for _, param in self._parameters_to_sync[pipe_stage]]
        dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
        debug_rank_0(f"{self.name} Get dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger)
        for bucket_id, bucket in enumerate(dense_buckets):
            put_ref = name2ref[group_name + ":dense_bucket_" + str(bucket_id)]
            flat_tensors = ray.get(put_ref)
            for tensor, synced in zip(
                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
                tensor.copy_(synced)
        for param_id, param in enumerate(sparse_bucket):
            put_ref = name2ref[group_name + ":sparse_bucket_" + str(param_id)]
            param.copy_(ray.get(put_ref))

    def pipeline_model_parallel_size(self):
        """
        :meta private:
        """
        return self.module_args.pipeline_model_parallel_size

    def tensor_model_parallel_size(self):
        """
        :meta private:
        """
        return self.module_args.tensor_model_parallel_size

    def expert_model_parallel_size(self):
        """
        :meta private:
        """
        return self.module_args.expert_model_parallel_size

    def num_layers(self):
        """
        :meta private:
        """

    def set_storage(self, storage):
        """
        :meta private:
        """
        self._storage = storage

    def timers(self, name):
        """
        Return the timer called `name`, creating the `Timers` container on first use.

        :meta private:
        """
        if self._timers is None:
            self._timers = Timers()
        return self._timers(name)

    def timer_summary(self, e2e_cost=None):
        """
        Return the timers' log as a dict, or ``None`` when no timers exist.

        :meta private:
        """
        if self._timers:
            return self._timers.log(return_dict=True, e2e_cost=e2e_cost)
        return None

    def get_and_clear_metrics(self):
        """
        get logging metrics

        Returns ``(metric_prefix, reduced_metrics)`` and empties the metric list;
        ``reduced_metrics`` is ``{}`` when nothing was recorded.
        """
        if not self._metric_list:
            return self._metric_prefix, {}

        reduced_metrics = map_reduce_metrics(self._metric_list)
        self._metric_list = []
        return self._metric_prefix, reduced_metrics

    def add_padding_config(self, key, padding_value=0.0, padding_type="right"):
        """
        Add special padding config for certain value.

        Args
        ----
        key: str
            The key for data to be padded.
        padding_value: float
            Padding value, default is 0.
        padding_type: str
            Default right, can be right/left.
        """
        self._padding_config[key] = {"padding_value": padding_value, "padding_type": padding_type}

    def padding_config(self):
        """
        :meta private:
        """
        return self._padding_config

    def peak_memory(self):
        """
        :meta private:
        """
        return 0.0

    @property
    def resume_training(self):
        """
        resume training from last checkpoint.
        """
        return self._resume_training

    def get_address(self):
        """
        Get node address

        :meta private:
        """
        return self._address

    def is_master_node(self):
        """
        Whether this node is master node.

        :meta private:
        """
        return self._is_master_node

    def set_src_parameter_model(self, src_model):
        """
        src_model that sync parameter to current model

        :meta private:
        """
        self._src_parameter_model = src_model

    @property
    def src_parameter_model(self):
        """
        src_model that sync parameter to current model
        """
        return self._src_parameter_model

    def offload_optimizer_states(self):
        """
        offload optimizer states
        """

    def onload_optimizer_states(self):
        """
        onload optimizer states
        """

    def offload_main_weights(self):
        """
        offload main weights
        """

    def onload_main_weights(self):
        """
        onload main weights
        """

    def offload_weights(self):
        """
        offload weights
        """

    def onload_weights(self):
        """
        onload weights
        """

    def free_grad_buffers(self):
        """
        free grad buffers and related tensors
        """

    def build_grad_buffers(self):
        """
        build grad buffers and related tensors
        """

    def onload(self):
        pass

    def offload(self):
        pass

    @property
    def world_size(self):
        pass

    @property
    def data_parallel_size(self):
        """
        data parallel size

        :meta private:
        """

    @property
    def data_parallel_rank(self):
        """
        data parallel rank

        :meta private:
        """

    def empty_cache(self):
        """
        :meta private:
        """

    def get_data_parallel_rank(self):
        """Method form of the `data_parallel_rank` property."""
        return self.data_parallel_rank

    def get_data_parallel_size(self):
        """Method form of the `data_parallel_size` property."""
        return self.data_parallel_size

    def get_pipeline_stage_layer_num(self):
        pass

    def get_pipeline_stage_layer_offset(self):
        return 0

    def set_synchronizer(self, synchronizer):
        """Set the synchronizer used by the parameter-sync methods."""
        self._synchronizer = synchronizer

    def expert_parallel_rank(self):
        """
        :meta private:
        """
        return 0

    def enable_stage_resume(self, is_eval):
        """
        check whether to resume stage outputs.

        Always ``False`` for evaluation; otherwise follows the
        ``enable_stage_resume`` model argument, which requires
        ``data_checkpoint_path`` to be set.
        """
        if is_eval:
            return False
        if self.model_args.get("enable_stage_resume", False):
            assert self.runtime_args.data_checkpoint_path, \
                "data_checkpoint_path must be set for stage resume."
            return True
        return False

    def get_stage_outputs_path(self, iteration):
        """
        get path for stage outputs.

        Returns the ``(outputs_path, meta_path)`` pair for `iteration`.
        """
        save_dir = self.runtime_args.data_checkpoint_path
        save_path = f"{save_dir}/{iteration}/{self.name}_replica_{self.replica_id}.pt"
        save_path_meta = f"{save_dir}/{iteration}/{self.name}_replica_{self.replica_id}_meta.txt"
        return save_path, save_path_meta

    def load_stage_outputs(self, is_eval, iteration):
        """
        load stage outputs for resume.

        Returns the loaded outputs, or ``None`` when stage resume is disabled,
        was already attempted, the files are missing, or the meta file does not
        match this replica.
        """
        outputs = None
        # only load once for each launching.
        if self.enable_stage_resume(is_eval) and not self._stage_resume_done:
            self._stage_resume_done = True
            save_path, save_path_meta = self.get_stage_outputs_path(iteration)
            if os.path.exists(save_path) and os.path.exists(save_path_meta):
                try:
                    with open(save_path_meta, "r", encoding='utf-8') as f:
                        replica_id = int(f.readline())
                    if replica_id == self.replica_id:
                        outputs = torch.load(save_path)
                        logger.info(f"resume stage outputs for model:{self.name}, path:{save_path}")
                except ValueError:
                    # ValueError is caught here and the outputs are skipped.
                    logger.warning(f"ignore incomplete stage outputs, path:{save_path}")
        return outputs

    def save_stage_outputs(self, is_eval, outputs, iteration):
        """
        save stage outputs for resume.

        The outputs file is written first and the meta file (holding the replica
        id) afterwards.
        """
        if self.enable_stage_resume(is_eval):
            save_path, save_path_meta = self.get_stage_outputs_path(iteration)
            logger.info(f"Start to save stage outputs:{save_path}")
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            torch.save(outputs, save_path)
            # save meta
            with open(save_path_meta, "w", encoding='utf-8') as f:
                f.write(f"{self.replica_id}")
            logger.info(f"Finished to save stage outputs:{save_path}")