toolkits/distributed_checkpoints_convertor/impl/general/h2m_synchronizer.py (259 lines of code) (raw):

import os
import shutil
import torch
import json
import logging
from torch import distributed as dist
from safetensors import safe_open
from transformers.modeling_utils import SAFE_WEIGHTS_INDEX_NAME
from huggingface_hub.constants import SAFETENSORS_INDEX_FILE
from megatron.training.checkpointing import (
    save_checkpoint,
    get_checkpoint_tracker_filename,
    get_checkpoint_name
)

from general.synchronizer import BaseSynchronizer, ParamType


class HF2MGSynchronizer(BaseSynchronizer):

    def __init__(self, load_dir, model_provider_func=None):
        super().__init__(load_dir, model_provider_func)
        self._single_file = False
        # Locate the safetensors index; fall back to a single-file checkpoint.
        p = os.path.join(self.load_dir, SAFE_WEIGHTS_INDEX_NAME)
        if not os.path.exists(p):
            p = os.path.join(self.load_dir, SAFETENSORS_INDEX_FILE)
        if os.path.exists(p):
            with open(p, 'r') as f:
                data = json.load(f)['weight_map']
            self._key_to_file = {k: os.path.join(self.load_dir, v) for k, v in data.items()}
        elif os.path.exists(os.path.join(self.load_dir, 'model.safetensors')):
            self._single_file = True
            self._key_to_file = dict()
        else:
            raise FileNotFoundError(f'No safetensors index or model.safetensors found under {self.load_dir}')

        if self.debug:
            if not self.dryrun:
                for param in self._mgmodel.parameters():
                    param.data.fill_(torch.nan)
                # NOTE: Fill non-persistent/persistent buffers with NaN
                for b in self._mgmodel.buffers():
                    b.data.fill_(torch.nan)
            self._visit = torch.zeros([self.hf_size], dtype=torch.int, device=self.device)

    def load_tensor(self, dummy_tensor):
        def _get_filename_from_key(key):
            if self._single_file:
                return os.path.join(self.load_dir, 'model.safetensors')
            if key in self._key_to_file:
                return self._key_to_file[key]
            raise KeyError(f'{key} not found in index file')

        if dummy_tensor not in self._hf_params_to_key:
            raise ValueError('The given tensor is not a registered HuggingFace parameter')
        key = self._hf_params_to_key[dummy_tensor]
        if self.debug:
            self._visit[self._hf_params_key_to_id[key]] = True
        if not self.args.untie_embeddings_and_output_weights and key == 'lm_head.weight':
            key = 'model.embed_tokens.weight'
        file = _get_filename_from_key(key)
        with safe_open(file, framework="pt", device=str(self.device)) as f:
            return f.get_tensor(key)

    def _copy_impl(
        self,
        src_tensor,
        dst_tensor,
        param_type: ParamType = ParamType.UNIQUE
    ):
        tp_rank, tp_size = self.tp_rank, self.tp_size
        if param_type in [ParamType.MOE_COLUMN, ParamType.MOE_ROW, ParamType.MOE_GATE_UP]:
            tp_rank, tp_size = self.etp_rank, self.etp_size
        split_mapping = {
            ParamType.UNIQUE: lambda x: self.load_tensor(x),
            ParamType.COLUMN: lambda x: torch.chunk(self.load_tensor(x), tp_size, dim=0)[tp_rank],
            ParamType.ROW: lambda x: torch.chunk(self.load_tensor(x), tp_size, dim=1)[tp_rank],
            # tensors of the following types are loaded by the caller
            ParamType.GATE_UP: lambda x: torch.chunk(x, tp_size, dim=1)[tp_rank].flatten(0, 1),
            ParamType.QKV_W: lambda x: torch.chunk(x, tp_size, dim=0)[tp_rank].flatten(0, 1),
            ParamType.QKV_B: lambda x: torch.chunk(x, tp_size, dim=0)[tp_rank].flatten(),
            ParamType.MOE_COLUMN: lambda x: torch.chunk(self.load_tensor(x), tp_size, dim=0)[tp_rank],
            ParamType.MOE_ROW: lambda x: torch.chunk(self.load_tensor(x), tp_size, dim=1)[tp_rank],
            # the tensor of the following type is loaded by the caller
            ParamType.MOE_GATE_UP: lambda x: torch.chunk(x, tp_size, dim=1)[tp_rank].flatten(0, 1),
        }

        if self.dryrun:
            return dst_tensor.data.copy_(dst_tensor.clone())
        dst_tensor.data.copy_(split_mapping[param_type](src_tensor))

    def set_preprocess_state(self):
        '''Set embedding params.'''
        self.copy(
            self._hfmodel.model.embed_tokens.weight,
            self._mgmodel.embedding.word_embeddings.weight,
            param_type=ParamType.COLUMN
        )
    def set_postprocess_state(self):
        '''Set output layer & norm params.'''
        self.copy(
            self._hfmodel.model.norm.weight,
            self._mgmodel.decoder.final_layernorm.weight,
        )
        if self._mgmodel.share_embeddings_and_output_weights:
            output_layer_weight = self._mgmodel.shared_embedding_or_output_weight()
        else:
            output_layer_weight = self._mgmodel.output_layer.weight
        self.copy(
            self._hfmodel.lm_head.weight,
            output_layer_weight,
            param_type=ParamType.COLUMN
        )

    def set_mla_selfattn_state(self, attn, hf_attn):
        '''Set multi-latent attention params.'''
        # NOTE: MLA qkv_bias is always False
        if self.args.q_lora_rank is None:
            self.copy(hf_attn.q_proj.weight, attn.linear_q_proj.weight, param_type=ParamType.COLUMN)
        else:
            self.copy(hf_attn.q_a_proj.weight, attn.linear_q_down_proj.weight, param_type=ParamType.COLUMN)
            self.copy(hf_attn.q_b_proj.weight, attn.linear_q_up_proj.weight, param_type=ParamType.COLUMN)
            if self.args.qk_layernorm:
                self.copy(
                    hf_attn.q_a_layernorm.weight,
                    attn.linear_q_up_proj.layer_norm_weight
                )

        self.copy(hf_attn.kv_a_proj_with_mqa.weight, attn.linear_kv_down_proj.weight, param_type=ParamType.COLUMN)
        self.copy(hf_attn.kv_b_proj.weight, attn.linear_kv_up_proj.weight, param_type=ParamType.COLUMN)
        if self.args.qk_layernorm:
            self.copy(
                hf_attn.kv_a_layernorm.weight,
                attn.linear_kv_up_proj.layer_norm_weight
            )

        self.copy(
            hf_attn.o_proj.weight,
            attn.linear_proj.weight,
            param_type=ParamType.ROW
        )

    def set_selfattn_state(self, attn, hf_attn):
        '''Set self-attention params.'''
        # Reshape loaded weights.
        num_heads = self.args.num_attention_heads
        num_query_groups = (self.args.num_query_groups
                            if self.args.group_query_attention
                            else self.args.num_attention_heads)
        num_queries_per_group = num_heads // num_query_groups
        dim = self.args.kv_channels
        # Attention heads must divide evenly into query groups.
        assert num_heads % num_query_groups == 0

        # Copy QK layernorm if enabled.
        if self.args.qk_layernorm:
            self.copy(hf_attn.q_norm.weight, attn.q_layernorm.weight)
            self.copy(hf_attn.k_norm.weight, attn.k_layernorm.weight)

        # Copy weights (re-order dimensions for Megatron).
        if self.dryrun:
            attn_proj_weight = attn.linear_qkv.weight
        else:
            attn_proj_weight = torch.cat([
                self.load_tensor(hf_attn.q_proj.weight).reshape((num_query_groups, num_queries_per_group*dim, -1)),
                self.load_tensor(hf_attn.k_proj.weight).reshape((num_query_groups, dim, -1)),
                self.load_tensor(hf_attn.v_proj.weight).reshape((num_query_groups, dim, -1)),
            ], dim=1)
        self.copy(
            attn_proj_weight,
            attn.linear_qkv.weight,
            param_type=ParamType.QKV_W,
        )

        self.copy(
            hf_attn.o_proj.weight,
            attn.linear_proj.weight,
            param_type=ParamType.ROW
        )

        # Copy bias
        if self.args.add_qkv_bias:
            if self.dryrun:
                attn_proj_bias = attn.linear_qkv.bias
            else:
                attn_proj_bias = torch.cat([
                    self.load_tensor(hf_attn.q_proj.bias).reshape((num_query_groups, num_queries_per_group*dim, -1)),
                    self.load_tensor(hf_attn.k_proj.bias).reshape((num_query_groups, dim, -1)),
                    self.load_tensor(hf_attn.v_proj.bias).reshape((num_query_groups, dim, -1)),
                ], dim=1)
            self.copy(
                attn_proj_bias,
                attn.linear_qkv.bias,
                param_type=ParamType.QKV_B,
            )
    def set_mlp_state(self, mlp, hf_mlp, expert_id=''):
        '''
        Set MLP params. The mlp (mcore MLP) should have attributes
        `linear_fc1` and `linear_fc2`. Currently only Gated Linear is supported.
        '''
        if self.dryrun:
            gate_up_proj_weight = mlp.linear_fc1.weight
        else:
            gate_up_proj_weight = torch.stack([
                self.load_tensor(hf_mlp.gate_proj.weight),
                self.load_tensor(hf_mlp.up_proj.weight)
            ])
        linear_fc1_weight = getattr(mlp.linear_fc1, f'weight{expert_id}')
        linear_fc2_weight = getattr(mlp.linear_fc2, f'weight{expert_id}')
        self.copy(
            gate_up_proj_weight,
            linear_fc1_weight,
            param_type=ParamType.GATE_UP if expert_id == '' else ParamType.MOE_GATE_UP
        )
        self.copy(
            hf_mlp.down_proj.weight,
            linear_fc2_weight,
            param_type=ParamType.ROW if expert_id == '' else ParamType.MOE_ROW
        )

    def set_sequential_mlp_state(self, experts, hf_experts):
        '''Set MoE MLP params for sequential experts.'''
        experts = experts.local_experts
        for mg_expert_id, hf_expert_id in self._build_expert_parallel_mapping().items():
            self.set_mlp_state(experts[mg_expert_id], hf_experts[hf_expert_id])

    def set_group_mlp_state(self, experts, hf_experts):
        '''Set MoE MLP params for grouped-GEMM experts.'''
        for mg_expert_id, hf_expert_id in self._build_expert_parallel_mapping().items():
            self.set_mlp_state(experts, hf_experts[hf_expert_id], expert_id=mg_expert_id)

    def set_moe_layer_state(self, moe, hf_moe):
        '''Set MoE router and expert params.'''
        # router
        self.copy(hf_moe.gate.weight, moe.router.weight)
        if moe.router.enable_expert_bias:
            self.copy(hf_moe.gate.e_score_correction_bias, moe.router.expert_bias)
        # experts
        if self.args.moe_grouped_gemm:
            # grouped GEMM
            if self.args.moe_use_legacy_grouped_gemm:
                # weight1 and weight2, not implemented
                raise NotImplementedError("Currently only TE GroupGEMM is implemented.")
            self.set_group_mlp_state(moe.experts, hf_moe.experts)
        else:
            # sequential
            self.set_sequential_mlp_state(moe.experts, hf_moe.experts)
        # shared experts
        if moe.shared_experts is not None:
            if moe.shared_experts.use_shared_expert_gate:
                self.copy(hf_moe.shared_expert_gate.weight, moe.shared_experts.gate_weight)
            self.set_mlp_state(moe.shared_experts, hf_moe.shared_experts)

    def set_layer_state(self, layer, hf_layer):
        '''Set transformer layer params.'''
        if self.args.multi_latent_attention:
            self.set_mla_selfattn_state(layer.self_attention, hf_layer.self_attn)
            self.copy(hf_layer.input_layernorm.weight, layer.input_layernorm.weight)
        else:
            self.set_selfattn_state(layer.self_attention, hf_layer.self_attn)
            self.copy(hf_layer.input_layernorm.weight, layer.self_attention.linear_qkv.layer_norm_weight)

        if hasattr(layer.mlp, 'router'):
            self.set_moe_layer_state(layer.mlp, hf_layer.mlp)
            self.copy(hf_layer.post_attention_layernorm.weight, layer.pre_mlp_layernorm.weight)
        else:
            self.set_mlp_state(layer.mlp, hf_layer.mlp)
            self.copy(hf_layer.post_attention_layernorm.weight, layer.mlp.linear_fc1.layer_norm_weight)

    def check_and_save(self, output_dir):
        if self.debug:
            if not self.dryrun:
                for n, p in self._mgmodel.state_dict().items():
                    if isinstance(p, torch.Tensor) and p.isnan().any():
                        raise SystemError(f'NaN parameters detected on key {n}')
            dist.all_reduce(self._visit)
            unvisited_param_ids = (self._visit == 0).nonzero()[:, 0].cpu().numpy().tolist()
            unvisited_keys = [self._id_to_hf_params_key[param_id] for param_id in unvisited_param_ids]
            if len(unvisited_keys) > 0:
                logging.warning(
                    f"The following HuggingFace weights were never visited during conversion: {unvisited_keys}"
                )

        self.args.save = output_dir
        if not self.dryrun:
            save_checkpoint(
                getattr(self.args, 'iteration', 1),
                [self._mgmodel],
                None, None, 0,
                pipeline_rank=self.pp_rank,
                pipeline_parallel=self.pp_size > 1,
                expert_rank=self.ep_rank,
                expert_parallel=self.ep_size > 1,
                tensor_rank=self.tp_rank
            )
            dist.barrier()
            if self.rank == 0:
                # NOTE: The `save_checkpoint` API can only save a checkpoint with
                # release=False, so reset the metadata here. (Otherwise users may
                # find that their training starts at step 2.)
                tracker_filename = get_checkpoint_tracker_filename(self.args.save)
                with open(tracker_filename, 'w') as f:
                    f.write('release')
                source_dir = get_checkpoint_name(self.args.save, 1, False, return_base_dir=True)
                target_dir = get_checkpoint_name(self.args.save, -1, True, return_base_dir=True)
                shutil.move(source_dir, target_dir)
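

# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original module): a minimal sketch of how
# HF2MGSynchronizer might be driven to convert a HuggingFace checkpoint into a
# Megatron-Core one. The driver below is an assumption -- the toolkit's real
# entry point lives elsewhere -- and it ignores pipeline parallelism (it assumes
# every decoder layer is local to this rank).
#
#   sync = HF2MGSynchronizer('/path/to/hf_checkpoint', model_provider_func=model_provider)
#   sync.set_preprocess_state()                       # embeddings
#   for mg_layer, hf_layer in zip(sync._mgmodel.decoder.layers,
#                                 sync._hfmodel.model.layers):
#       sync.set_layer_state(mg_layer, hf_layer)      # attention + MLP/MoE per layer
#   sync.set_postprocess_state()                      # final norm + output layer
#   sync.check_and_save('/path/to/mcore_checkpoint')  # verify coverage, save, mark release
# ---------------------------------------------------------------------------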