# toolkits/distributed_checkpoints_convertor/impl/general/m2h_synchronizer.py

import os
import torch
import json
import logging
from typing import *
from collections import defaultdict
from functools import partial
from torch import distributed as dist
from safetensors.torch import save_file as safe_save_file
from huggingface_hub.serialization import split_torch_state_dict_into_shards
from megatron.training.checkpointing import load_checkpoint
from general.synchronizer import BaseSynchronizer, ParamType


class ParamMergeError(ValueError):
    """Raised when the shards collected for one HF parameter cannot be merged."""


class MG2HFSynchronizer(BaseSynchronizer):
    """Export a parallel-sharded Megatron model as HuggingFace safetensors files.

    Workflow:
      1. The ``set_*_state`` methods call ``self.copy`` (defined in ``BaseSynchronizer``,
         presumably dispatching to ``_copy_impl``). This records which Megatron shard
         this rank holds for each HF parameter id, together with the ``ParamType``
         used to merge the shards.
      2. ``check_and_save`` splits the HF state dict into safetensors buckets, assigns
         buckets to saver ranks, exchanges shards point-to-point, merges them and
         writes the files.
    """

    def __init__(self, load_dir, model_provider_func=None):
        super().__init__(load_dir, model_provider_func)
        if not self.dryrun:
            load_checkpoint(
                [self._mgmodel],
                None,
                None,
                checkpointing_context=None,
                skip_load_to_model_and_opt=False,
            )
        self.num_savers = self.args.num_hf_saver
        self.max_shard_size = self.args.max_shard_size
        # NOTE: mapping unique global id (0 ~ n-1) to local data
        self._local_params = dict()
        # NOTE: mapping global id to global metadata of the *sharded* tensor
        # (i.e. the per-rank shard shape rather than the HF param shape)
        self._tensor_shape = dict()
        self._tensor_dtype = dict()
        # mapping rank to (tp, pp, etp, ep, dp, edp)
        self._rank_mapping = torch.zeros([self.world_size, 6], dtype=torch.int, device=self.device)
        self._rank_mapping[self.rank] = torch.tensor(
            [self.tp_rank, self.pp_rank, self.etp_rank, self.ep_rank, self.dp_rank, self.edp_rank],
            dtype=torch.int,
            device=self.device,
        )
        dist.all_gather_into_tensor(self._rank_mapping, self._rank_mapping[self.rank])
        # Merge-function type (ParamType value) for each param; entries equal to 0 are
        # treated as "not registered on this rank" when merging across ranks.
        self._merge_type: torch.Tensor = torch.zeros([self.hf_size], dtype=torch.int, device=self.device)
        # self._has_param[param_id].nonzero() ==> ranks that have this param
        self._has_param: Optional[torch.Tensor] = None

    def _copy_impl(self, src_tensor, dst_tensor, param_type: ParamType = ParamType.UNIQUE):
        """Register ``src_tensor`` as this rank's shard of the HF parameter ``dst_tensor``.

        Only one data-parallel replica registers a shard: expert params on
        ``edp_rank == 0``, unique params on ``dp_rank == 0 and edp_rank == 0``, and
        every other param on ``dp_rank == 0``.
        """
        param_id = self._hf_params_to_id[dst_tensor]
        if param_type in (ParamType.MOE_COLUMN, ParamType.MOE_ROW, ParamType.MOE_GATE_UP):
            # NOTE: only register on edp_rank 0
            is_owner = self.edp_rank == 0
        elif param_type == ParamType.UNIQUE:
            is_owner = self.dp_rank == 0 and self.edp_rank == 0
        else:
            is_owner = self.dp_rank == 0
        if not is_owner:
            return
        self._local_params[param_id] = src_tensor
        self._merge_type[param_id] = param_type.value

    def _copy_gated_fc(self, fc1_weight, fc2_weight, hf_mlp,
                       column_type=ParamType.COLUMN, row_type=ParamType.ROW):
        """Split a fused ``[gate; up]`` fc1 weight and register gate/up/down projections."""
        hidden_size = fc1_weight.shape[-1]
        gate_proj_weight, up_proj_weight = fc1_weight.reshape(2, -1, hidden_size)
        self.copy(gate_proj_weight, hf_mlp.gate_proj.weight, param_type=column_type)
        self.copy(up_proj_weight, hf_mlp.up_proj.weight, param_type=column_type)
        self.copy(fc2_weight, hf_mlp.down_proj.weight, param_type=row_type)

    def set_preprocess_state(self):
        """Set embedding params."""
        self.copy(
            self._mgmodel.embedding.word_embeddings.weight,
            self._hfmodel.model.embed_tokens.weight,
            param_type=ParamType.COLUMN,
        )

    def set_postprocess_state(self):
        """Set output layer & norm params."""
        self.copy(self._mgmodel.decoder.final_layernorm.weight, self._hfmodel.model.norm.weight)
        if self._mgmodel.share_embeddings_and_output_weights:
            output_layer_weight = self._mgmodel.shared_embedding_or_output_weight()
        else:
            output_layer_weight = self._mgmodel.output_layer.weight
        self.copy(output_layer_weight, self._hfmodel.lm_head.weight, param_type=ParamType.COLUMN)

    def set_mla_selfattn_state(self, attn, hf_attn):
        """Set multi-latent-attention params."""
        # NOTE: MLA qkv_bias always False
        if self.args.q_lora_rank is None:
            self.copy(attn.linear_q_proj.weight, hf_attn.q_proj.weight, param_type=ParamType.COLUMN)
        else:
            self.copy(attn.linear_q_down_proj.weight, hf_attn.q_a_proj.weight, param_type=ParamType.COLUMN)
            self.copy(attn.linear_q_up_proj.weight, hf_attn.q_b_proj.weight, param_type=ParamType.COLUMN)
            if self.args.qk_layernorm:
                self.copy(attn.linear_q_up_proj.layer_norm_weight, hf_attn.q_a_layernorm.weight)
        self.copy(attn.linear_kv_down_proj.weight, hf_attn.kv_a_proj_with_mqa.weight, param_type=ParamType.COLUMN)
        self.copy(attn.linear_kv_up_proj.weight, hf_attn.kv_b_proj.weight, param_type=ParamType.COLUMN)
        if self.args.qk_layernorm:
            self.copy(attn.linear_kv_up_proj.layer_norm_weight, hf_attn.kv_a_layernorm.weight)
        self.copy(attn.linear_proj.weight, hf_attn.o_proj.weight, param_type=ParamType.ROW)

    def set_selfattn_state(self, attn, hf_attn):
        """Set self-attention params.

        Raises:
            ValueError: if ``num_attention_heads`` is not a multiple of the number of
                query groups.
        """
        tp = self.tp_size
        num_heads = self.args.num_attention_heads
        num_query_groups = (
            self.args.num_query_groups if self.args.group_query_attention else self.args.num_attention_heads
        )
        # Check heads vs. groups directly: the former ``num_heads % (num_heads // groups)``
        # test accepted e.g. 12 heads in 5 groups.
        if num_heads % num_query_groups != 0:
            raise ValueError(
                f"num_attention_heads ({num_heads}) must be divisible by num_query_groups ({num_query_groups})."
            )
        num_querys_per_group = num_heads // num_query_groups
        dim = self.args.kv_channels

        # copy qk norm if indeed.
        if self.args.qk_layernorm:
            self.copy(attn.q_layernorm.weight, hf_attn.q_norm.weight)
            self.copy(attn.k_layernorm.weight, hf_attn.k_norm.weight)

        # The local fused QKV rows are viewed as (groups_per_tp_rank, [q..., k, v] rows, -1)
        # and split into q / k / v along dim 1 (re-order dimensions for HF).
        attn_proj_weight = attn.linear_qkv.weight.reshape(
            (num_query_groups // tp, (2 + num_querys_per_group) * dim, -1)
        )
        q_proj_weight, k_proj_weight, v_proj_weight = torch.split(
            attn_proj_weight, [num_querys_per_group * dim, dim, dim], dim=1
        )
        self.copy(q_proj_weight, hf_attn.q_proj.weight, param_type=ParamType.QKV_W)
        self.copy(k_proj_weight, hf_attn.k_proj.weight, param_type=ParamType.QKV_W)
        self.copy(v_proj_weight, hf_attn.v_proj.weight, param_type=ParamType.QKV_W)

        self.copy(attn.linear_proj.weight, hf_attn.o_proj.weight, param_type=ParamType.ROW)

        # Copy bias
        if self.args.add_qkv_bias:
            attn_proj_bias = attn.linear_qkv.bias.reshape(
                (num_query_groups // tp, (2 + num_querys_per_group) * dim, -1)
            )
            q_proj_bias, k_proj_bias, v_proj_bias = torch.split(
                attn_proj_bias, [num_querys_per_group * dim, dim, dim], dim=1
            )
            self.copy(q_proj_bias, hf_attn.q_proj.bias, param_type=ParamType.QKV_B)
            self.copy(k_proj_bias, hf_attn.k_proj.bias, param_type=ParamType.QKV_B)
            self.copy(v_proj_bias, hf_attn.v_proj.bias, param_type=ParamType.QKV_B)

    def set_mlp_state(self, mlp, hf_mlp):
        """Set MLP params."""
        self._copy_gated_fc(mlp.linear_fc1.weight, mlp.linear_fc2.weight, hf_mlp)

    def set_sequential_mlp_state(self, experts, hf_experts):
        """Set MOE MLP params (SequentialMLP: one module per local expert)."""
        experts = experts.local_experts
        for mg_expert_id, hf_expert_id in self._build_expert_parallel_mapping().items():
            expert = experts[mg_expert_id]
            self._copy_gated_fc(
                expert.linear_fc1.weight,
                expert.linear_fc2.weight,
                hf_experts[hf_expert_id],
                column_type=ParamType.MOE_COLUMN,
                row_type=ParamType.MOE_ROW,
            )

    def set_group_mlp_state(self, experts, hf_experts):
        """Set MOE MLP params (TE GroupedMLP: ``weight{i}`` attributes per local expert)."""
        for mg_expert_id, hf_expert_id in self._build_expert_parallel_mapping().items():
            self._copy_gated_fc(
                getattr(experts.linear_fc1, f'weight{mg_expert_id}'),
                getattr(experts.linear_fc2, f'weight{mg_expert_id}'),
                hf_experts[hf_expert_id],
                column_type=ParamType.MOE_COLUMN,
                row_type=ParamType.MOE_ROW,
            )

    def set_moe_layer_state(self, moe, hf_moe):
        """Set router, routed-expert and shared-expert params of an MoE layer."""
        # router
        self.copy(moe.router.weight, hf_moe.gate.weight)
        if moe.router.enable_expert_bias:
            self.copy(moe.router.expert_bias, hf_moe.gate.e_score_correction_bias)

        # experts
        if self.args.moe_grouped_gemm:
            # group gemm
            if self.args.moe_use_legacy_grouped_gemm:
                # weight1 and weight2, not impl
                raise NotImplementedError("Currently only TE GroupGEMM is implemented.")
            self.set_group_mlp_state(moe.experts, hf_moe.experts)
        else:
            # sequential
            self.set_sequential_mlp_state(moe.experts, hf_moe.experts)

        # shared experts
        if moe.shared_experts is not None:
            if moe.shared_experts.use_shared_expert_gate:
                self.copy(moe.shared_experts.gate_weight, hf_moe.shared_expert_gate.weight)
            self._copy_gated_fc(
                moe.shared_experts.linear_fc1.weight,
                moe.shared_experts.linear_fc2.weight,
                hf_moe.shared_experts,
            )

    def set_layer_state(self, layer, hf_layer):
        """Set transformer layer params."""
        if self.args.multi_latent_attention:
            self.set_mla_selfattn_state(layer.self_attention, hf_layer.self_attn)
            self.copy(layer.input_layernorm.weight, hf_layer.input_layernorm.weight)
        else:
            self.set_selfattn_state(layer.self_attention, hf_layer.self_attn)
            self.copy(layer.self_attention.linear_qkv.layer_norm_weight, hf_layer.input_layernorm.weight)

        if hasattr(layer.mlp, 'router'):
            self.set_moe_layer_state(layer.mlp, hf_layer.mlp)
            self.copy(layer.pre_mlp_layernorm.weight, hf_layer.post_attention_layernorm.weight)
        else:
            self.set_mlp_state(layer.mlp, hf_layer.mlp)
            self.copy(layer.mlp.linear_fc1.layer_norm_weight, hf_layer.post_attention_layernorm.weight)

    def _select_local_buckets(self, sharded_info) -> Tuple[List[str], int]:
        """Pick the safetensors buckets this rank is responsible for writing.

        Returns:
            ``(local_buckets, max_n_local_buckets)``. The second value is identical on
            every rank and is the number of save rounds; each round runs collectives.
            Buckets are spread contiguously over the first ``n_savers`` ranks, and the
            first ``n_buckets % n_savers`` ranks take one extra bucket.
        """
        filenames = list(sharded_info.filename_to_tensors.keys())
        if not sharded_info.is_sharded:
            # Single output file: rank 0 writes it, every rank takes part in one round.
            return (filenames if self.rank == 0 else []), 1

        # Never plan for more savers than existing ranks: buckets assigned to
        # non-existent ranks would otherwise silently never be written.
        n_savers = max(1, min(self.num_savers, self.world_size))
        n_per_saver, remainder = divmod(len(filenames), n_savers)
        max_n_local_buckets = n_per_saver + (1 if remainder else 0)
        if self.rank >= n_savers:
            return [], max_n_local_buckets
        if self.rank < remainder:
            start, count = self.rank * (n_per_saver + 1), n_per_saver + 1
        else:
            start, count = self.rank * n_per_saver + remainder, n_per_saver
        return filenames[start:start + count], max_n_local_buckets

    def check_and_save(self, output_dir):
        """Collect, merge and write the HF checkpoint into ``output_dir``.

        Must be called on every rank: each save round runs collectives (all-gather,
        batched P2P and a barrier), and all ranks run the same number of rounds.

        Raises:
            ValueError: on unregistered params (``ParamType.NULL``), merge failures, or
                merged tensors whose shape differs from the HF model's parameter.
        """
        state_dict = self._hfmodel.state_dict()
        sharded_info = split_torch_state_dict_into_shards(state_dict, max_shard_size=self.max_shard_size)
        global_shape = {self._hf_params_key_to_id[k]: v.shape for k, v in state_dict.items()}

        # select local bucket(s) for each rank
        local_buckets, max_n_local_buckets = self._select_local_buckets(sharded_info)

        if self.rank == 0:
            os.makedirs(output_dir, exist_ok=True)
            if sharded_info.is_sharded:
                index = {
                    "metadata": sharded_info.metadata,
                    "weight_map": sharded_info.tensor_to_filename,
                }
                with open(os.path.join(output_dir, "model.safetensors.index.json"), "w") as f:
                    f.write(json.dumps(index, indent=2))

        self._collect_dist_info()
        # Host copy of the merged registry: avoids one device sync per param lookup below.
        merge_types = self._merge_type.tolist()

        # In each iteration, all ranks save at most one local bucket
        for bucket_idx in range(max_n_local_buckets):
            bucket_name = local_buckets[bucket_idx] if bucket_idx < len(local_buckets) else None
            required_keys: List[str] = (
                sharded_info.filename_to_tensors[bucket_name] if bucket_name is not None else []
            )

            # build send/recv op across all ranks
            data, buffers, _, recv_param_ids, ops = self._build_p2p_ops(required_keys)

            # run data sync
            if self.debug:
                logging.info(f"[Iters {bucket_idx} RANK {self.rank}] starts synchronizing parameters with other ranks...")
            if ops:
                reqs = dist.batch_isend_irecv(ops)
                if self.debug:
                    for op in ops:
                        size_mib = op.tensor.numel() * op.tensor.dtype.itemsize / 2 ** 20
                        if op.op == dist.isend:
                            logging.info(f"[Iters {bucket_idx} RANK {self.rank}] ({self.rank} -> {op.peer}) with {size_mib} MiB.")
                        else:
                            logging.info(f"[Iters {bucket_idx} RANK {self.rank}] ({op.peer} -> {self.rank}) with {size_mib} MiB.")
                for req in reqs:
                    req.wait()
            if self.debug:
                logging.info(f"[Iters {bucket_idx} RANK {self.rank}] finishes synchronizing")

            for remote_rank, param_ids in recv_param_ids.items():
                received = self._unpack_from_buffer(buffers[remote_rank], param_ids)
                for param_id, tensor in zip(param_ids, received):
                    data[param_id][remote_rank] = tensor
            del buffers

            # A requested key with no registered owner would otherwise silently be
            # absent from the saved file.
            missing_keys = [k for k in required_keys if self._hf_params_key_to_id[k] not in data]
            if missing_keys:
                logging.warning(
                    f"[Iters {bucket_idx} RANK {self.rank}] no rank registered data for "
                    f"{missing_keys}; they are not saved into {bucket_name}."
                )

            # apply merge function on the results
            # data: Dict[param_id, Dict[rank_id, tensor]]
            output_data = dict()
            for param_id, data_dict in data.items():
                key = self._id_to_hf_params_key[param_id]
                param_type = ParamType(merge_types[param_id])
                if param_type == ParamType.NULL:
                    raise ValueError(f"ParamType.NULL found on {key}.")
                try:
                    output_data[key] = self._merge_data(param_type, data_dict)
                except ParamMergeError as e:
                    raise ValueError(f"Merge Error on key {key}: {e}") from e
                if output_data[key].shape != global_shape[param_id]:
                    raise ValueError(f"Unexpected shape on {key}. Expected: {global_shape[param_id]}, but {output_data[key].shape}")
            del data

            # save safetensor files
            if bucket_name is not None:
                if not self.dryrun:
                    # Only rank 0 created the directory above; create it here as well in
                    # case this saver does not share rank 0's filesystem.
                    os.makedirs(output_dir, exist_ok=True)
                    safe_save_file(
                        output_data,
                        os.path.join(output_dir, bucket_name),
                        metadata={"format": "pt"},
                    )
                print(f"[Iters {bucket_idx} RANK {self.rank}] {bucket_name} is saved.")
            # Release this round's merged tensors before the next round allocates new ones.
            del output_data
            if self.debug:
                logging.debug(f"[Iters {bucket_idx} RANK {self.rank}] joined")
            dist.barrier()

    def _merge_remote_metadata(self, local: dict, gathered: list, what: str):
        """Fold the per-rank ``{param_id: value}`` dicts in ``gathered`` into ``local``.

        Entries from this rank are skipped; unknown ids are added.

        Raises:
            ValueError: if a remote rank reports a value different from the known one.
        """
        for rank, remote in enumerate(gathered):
            if rank == self.rank:
                continue
            for remote_key, value in remote.items():
                if remote_key not in local:
                    local[remote_key] = value
                elif value != local[remote_key]:
                    raise ValueError(
                        f"Find mismatched {what} on local rank {self.rank} and remote rank {rank}, "
                        f"local {what}: {local[remote_key]}; remote {what}: {value}"
                    )

    def _collect_dist_info(self):
        """Share the per-rank registrations so every rank knows the global layout.

        Fills ``self._has_param`` (param_id x rank bool matrix), completes
        ``self._tensor_shape`` / ``self._tensor_dtype`` with the shard metadata of other
        ranks, and completes ``self._merge_type`` with merge types registered elsewhere.

        Raises:
            ValueError: on conflicting shard shapes/dtypes, or (debug mode only)
                conflicting merge types between ranks.
        """
        # param_id --> source_rank
        has_param = torch.zeros(
            [self.world_size, len(self._hf_params_key_to_id)], dtype=torch.bool, device=self.device
        )
        for param_id in self._local_params:
            has_param[self.rank][param_id] = True
        dist.all_gather_into_tensor(has_param, has_param[self.rank])
        self._has_param = has_param.T

        # param_id --> tensor_shape Dict[int, Tuple[int, ...]]
        # param_id --> tensor_dtype Dict[int, dtype]
        for param_id, param in self._local_params.items():
            self._tensor_shape[param_id] = param.shape
            self._tensor_dtype[param_id] = param.dtype

        # collect across ranks, then merge them together
        tensor_shapes = [None] * self.world_size
        tensor_dtypes = [None] * self.world_size
        dist.all_gather_object(tensor_shapes, self._tensor_shape)
        dist.all_gather_object(tensor_dtypes, self._tensor_dtype)
        self._merge_remote_metadata(self._tensor_shape, tensor_shapes, "shape")
        self._merge_remote_metadata(self._tensor_dtype, tensor_dtypes, "dtype")

        # merge_type
        global_merge_type = torch.zeros(
            [self.world_size, self.hf_size], dtype=self._merge_type.dtype, device=self.device
        )
        dist.all_gather_into_tensor(global_merge_type, self._merge_type)
        for remote_rank_id, remote_merge_type in enumerate(global_merge_type):
            remote_set = remote_merge_type > 0
            if self.debug:
                conflict = remote_set & (self._merge_type > 0) & (self._merge_type != remote_merge_type)
                if conflict.any():
                    param_id = int(conflict.nonzero()[0])
                    key = self._id_to_hf_params_key[param_id]
                    raise ValueError(f"Find mismatched merge_type between local rank {self.rank} and remote rank {remote_rank_id} on key {key}")
            self._merge_type[remote_set] = remote_merge_type[remote_set]

    def _unpack_from_buffer(self, buffer: torch.Tensor, param_ids: List[int]) -> List[torch.Tensor]:
        """Split a received uint8 buffer back into tensors, in ``param_ids`` order.

        Raises:
            ValueError: if the buffer size differs from the expected total byte count.
        """
        start = 0
        datas = []
        for param_id in param_ids:
            shape = self._tensor_shape[param_id]
            dtype = self._tensor_dtype[param_id]
            nbytes = shape.numel() * dtype.itemsize
            # Clone the byte slice *before* reinterpreting it: ``Tensor.view(dtype)`` to a
            # wider dtype requires the storage offset to be a multiple of the element-size
            # ratio, which does not hold when tensors of different dtypes share a buffer.
            datas.append(buffer[start:start + nbytes].clone().view(dtype).view(shape))
            start += nbytes
        if start != buffer.numel():
            raise ValueError(f"Expect {start} bytes from remote, but got {buffer.numel()} bytes!")
        return datas

    def _pack_into_byte_buffer(self, tensors: List[torch.Tensor]) -> torch.Tensor:
        """Concatenate the raw bytes of ``tensors`` (row-major order) into one uint8 tensor."""
        # ``contiguous()`` first: ``view(torch.uint8)`` needs a unit last stride, which a
        # strided 1-D tensor would not have (``flatten`` returns it unchanged).
        return torch.cat([t.contiguous().view(-1).view(torch.uint8) for t in tensors])

    def _build_p2p_ops(self, required_keys: List[str]):
        """Plan the point-to-point exchange that delivers ``required_keys`` to this rank.

        Every rank announces the keys it needs (all-gather). Each holder of a required
        param then sends it to every requester. Tensors for one peer are packed, in
        ascending param id order, into a single uint8 message (see the NOTE below).

        Returns:
            ``(collected_data, buffers, send_param_ids, recv_param_ids, ops)``:
            ``collected_data`` maps param_id -> {rank: tensor} and is pre-filled with local
            shards; ``buffers[r]`` is the receive buffer for rank ``r`` (or None);
            ``send_param_ids`` / ``recv_param_ids`` map peer rank -> ordered param ids;
            ``ops`` is the list of ``dist.P2POp`` to run.
        """
        required_ids = torch.zeros([self.world_size, self.hf_size], dtype=torch.bool, device=self.device)
        for k in required_keys:
            required_ids[self.rank][self._hf_params_key_to_id[k]] = True
        dist.all_gather_into_tensor(required_ids, required_ids[self.rank])

        # Bring the boolean masks to host once: iterating device tensors element by
        # element forces one synchronisation per element.
        required_ids = required_ids.cpu()  # (world_size, hf_size)
        has_param = self._has_param.cpu()  # (hf_size, world_size)

        send_recv_pattern = None
        if self.debug:
            # (param_id, src_rank, dst_rank): +1 per planned send, -1 per planned recv
            send_recv_pattern = torch.zeros([self.hf_size, self.world_size, self.world_size], dtype=torch.int)

        # NOTE: for rank i to rank j, send params by ascending order
        # NOTE: to avoid hangs observed in multi-nodes scenarios, we merge tensors sent to
        # the same remote rank into a large uint8 tensor
        send_param_ids = defaultdict(list)
        datas = defaultdict(list)
        for param_id in has_param[:, self.rank].nonzero().flatten().tolist():
            for remote_rank in required_ids[:, param_id].nonzero().flatten().tolist():
                if remote_rank == self.rank:
                    continue
                data = self._local_params[param_id]
                if data.device != self.device:
                    logging.warning(f"Find unexpected device {data.device} on key {self._id_to_hf_params_key[param_id]}, moving to {self.device}")
                    data = data.to(self.device)
                if data.dtype != self._tensor_dtype[param_id]:
                    raise ValueError(f"Get mismatched data type on key {self._id_to_hf_params_key[param_id]}")
                datas[remote_rank].append(data)
                send_param_ids[remote_rank].append(param_id)
                if send_recv_pattern is not None:
                    send_recv_pattern[param_id, self.rank, remote_rank] += 1

        send_ops = [
            dist.P2POp(
                dist.isend,
                self._pack_into_byte_buffer(raw_data),
                peer=remote_rank,
                tag=self.rank * self.world_size + remote_rank,  # (sender_rank, receiver_rank) ignored in NCCL
            )
            for remote_rank, raw_data in datas.items()
            if raw_data
        ]

        # NOTE: for rank i to rank j, recv params by ascending order
        collected_data = defaultdict(dict)
        buffer_size = [0] * self.world_size
        recv_param_ids = defaultdict(list)
        for param_id in required_ids[self.rank].nonzero().flatten().tolist():
            for remote_rank in has_param[param_id].nonzero().flatten().tolist():
                if remote_rank == self.rank:
                    collected_data[param_id][remote_rank] = self._local_params[param_id]
                    continue
                recv_param_ids[remote_rank].append(param_id)
                buffer_size[remote_rank] += (
                    self._tensor_shape[param_id].numel() * self._tensor_dtype[param_id].itemsize
                )
                if send_recv_pattern is not None:
                    send_recv_pattern[param_id, remote_rank, self.rank] -= 1

        recv_ops = []
        buffers = [None] * self.world_size
        for remote_rank, rank_size in enumerate(buffer_size):
            if rank_size == 0:
                continue
            buffers[remote_rank] = torch.empty(rank_size, dtype=torch.uint8, device=self.device)
            recv_ops.append(dist.P2POp(
                dist.irecv,
                buffers[remote_rank],
                peer=remote_rank,
                tag=remote_rank * self.world_size + self.rank,  # (sender_rank, receiver_rank) ignored in NCCL
            ))

        if send_recv_pattern is not None:
            send_recv_pattern = send_recv_pattern.to(self.device)
            dist.all_reduce(send_recv_pattern)
            # Check every entry: a global sum of zero can hide a +1 and a -1 that sit on
            # different (src, dst) pairs.
            if send_recv_pattern.any():
                for param_id, pattern_per_param in enumerate(send_recv_pattern):
                    if pattern_per_param.any():
                        raise ValueError(f"Mismatched send/recv ops detected on key {self._id_to_hf_params_key[param_id]}: {pattern_per_param}.")
                raise ValueError("Mismatched send/recv ops detected.")
        logging.debug(f"[RANK {self.rank}] {len(send_ops)} send op & {len(recv_ops)} recv op.")
        return collected_data, buffers, send_param_ids, recv_param_ids, (send_ops + recv_ops)

    def _merge_data(self, merge_type: ParamType, tensor_dict) -> torch.Tensor:
        """Merge the shards of one parameter collected across the group on this rank.

        ``tensor_dict`` may contain several copies of the same shard; they are
        deduplicated before merging.

        Args:
            merge_type (ParamType): The merge policy applied to the shards.
            tensor_dict (Dict[int, torch.Tensor]): The collected shards, mapping global
                rank to tensor data.

        Returns:
            torch.Tensor: The merged tensor.

        Raises:
            ParamMergeError: on inconsistent shard origins (debug mode only) or an
                unsupported merge type.
        """
        # tensor_dict: Dict[remote_rank, torch.Tensor]
        def merge_along_axis(axis, tensor_dict, is_expert: bool = False):
            global_ranks = torch.tensor(list(tensor_dict.keys()), dtype=torch.long, device=self.device)  # (N, )
            ranks = self._rank_mapping.index_select(0, global_ranks)  # (N, 6): tp, pp, etp, ep, dp, edp
            if self.debug:
                if is_expert and ranks[:, 5].any():
                    raise ParamMergeError("Unexpected expert parameter data from non-zero expert dp rank")
                if not is_expert and ranks[:, 4].any():
                    raise ParamMergeError("Unexpected parameter data from non-zero dp rank")
                if ranks[:, 1].max() != ranks[:, 1].min():
                    raise ParamMergeError("Unexpected parameter data from multiple pp ranks")
                if is_expert and ranks[:, 3].max() != ranks[:, 3].min():
                    raise ParamMergeError("Unexpected expert parameter data from multiple ep ranks")

            # Shards are ordered by etp rank (experts) or tp rank (dense params); copies
            # with the same shard index are duplicates and the last one seen is kept.
            shard_col = 2 if is_expert else 0
            shards = dict()
            for shard_idx, tensor in zip(ranks[:, shard_col].tolist(), tensor_dict.values()):
                shards[shard_idx] = tensor
            return torch.cat([shards[idx] for idx in sorted(shards)], dim=axis)

        def no_merge_func(tensor_dict):
            return next(iter(tensor_dict.values()))

        def merge_qkv(is_bias, tensor_dict):
            # Concatenate per-group slices, then fold (groups, rows) into one row axis.
            res = merge_along_axis(0, tensor_dict)
            if is_bias:
                assert res.shape[-1] == 1
                return res.flatten()
            return res.flatten(0, 1)

        merge_func_mapping = {
            ParamType.MOE_COLUMN: partial(merge_along_axis, 0, is_expert=True),
            ParamType.MOE_ROW: partial(merge_along_axis, 1, is_expert=True),
            ParamType.COLUMN: partial(merge_along_axis, 0),
            ParamType.ROW: partial(merge_along_axis, 1),
            ParamType.QKV_W: partial(merge_qkv, False),
            ParamType.QKV_B: partial(merge_qkv, True),
            ParamType.UNIQUE: no_merge_func,
        }
        merge_func = merge_func_mapping.get(merge_type)
        if merge_func is None:
            # e.g. MOE_GATE_UP is accepted by _copy_impl but has no merge rule here.
            raise ParamMergeError(f"Unsupported merge type {merge_type}")
        return merge_func(tensor_dict)