# toolkits/distributed_checkpoints_convertor/impl/general/m2h_synchronizer.py
import os
import torch
import json
import logging
from typing import Dict, List, Tuple
from collections import defaultdict
from functools import partial
from torch import distributed as dist
from safetensors.torch import save_file as safe_save_file
from huggingface_hub.serialization import split_torch_state_dict_into_shards
from megatron.training.checkpointing import load_checkpoint
from general.synchronizer import BaseSynchronizer, ParamType
class ParamMergeError(ValueError):
    '''Raised when collected tensor shards cannot be merged into a full HF parameter.'''
class MG2HFSynchronizer(BaseSynchronizer):
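    '''Convert a Megatron-Core (MG) distributed checkpoint into HuggingFace (HF) safetensors shards.
    Each rank registers the MG parameters it owns, the ranks exchange the required
    shards via batched P2P ops, merge them according to their ParamType, and the
    saver ranks write the HF shard files plus the safetensors index.
    '''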
def __init__(self, load_dir, model_provider_func=None):
super().__init__(load_dir, model_provider_func)
if not self.dryrun:
load_checkpoint(
[self._mgmodel],
None,
None,
checkpointing_context=None,
skip_load_to_model_and_opt=False
)
self.num_savers = self.args.num_hf_saver
self.max_shard_size = self.args.max_shard_size
# NOTE: mapping unique global id (0 ~ n-1) to local data
self._local_params = dict()
# NOTE: mapping global id to global metadata
        self._tensor_shape = dict()  # sharded shape (rather than the full hf param shape)
        self._tensor_dtype = dict()
# mapping rank to (tp, pp, etp, ep, dp, edp)
self._rank_mapping = torch.zeros([self.world_size, 6], dtype=torch.int, device=self.device)
        self._rank_mapping[self.rank] = torch.tensor([self.tp_rank, self.pp_rank, self.etp_rank, self.ep_rank, self.dp_rank, self.edp_rank], dtype=torch.int, device=self.device)
dist.all_gather_into_tensor(self._rank_mapping, self._rank_mapping[self.rank])
# define the merge function type for each param
self._merge_type: torch.Tensor = torch.zeros([self.hf_size], dtype=torch.int, device=self.device)
self._has_param: torch.Tensor = None # self._has_param[param_id].nonzero() ==> ranks that have this param
def _copy_impl(self, src_tensor, dst_tensor, param_type: ParamType=ParamType.UNIQUE):
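        '''Register src_tensor as the local shard of the HF param dst_tensor and record its merge type (only on dp/edp rank 0 to avoid duplicates).'''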
param_id = self._hf_params_to_id[dst_tensor]
if param_type in [ParamType.MOE_COLUMN, ParamType.MOE_ROW, ParamType.MOE_GATE_UP]:
# NOTE: only register on edp_rank 0
if self.edp_rank != 0:
return
elif param_type == ParamType.UNIQUE:
if self.dp_rank != 0 or self.edp_rank != 0:
return
elif self.dp_rank != 0:
return
self._local_params[param_id] = src_tensor
self._merge_type[param_id] = param_type.value
def set_preprocess_state(self):
'''Set embedding params.'''
self.copy(
self._mgmodel.embedding.word_embeddings.weight,
self._hfmodel.model.embed_tokens.weight,
param_type=ParamType.COLUMN
)
def set_postprocess_state(self):
'''Set output layer & norm params.'''
self.copy(self._mgmodel.decoder.final_layernorm.weight, self._hfmodel.model.norm.weight)
if self._mgmodel.share_embeddings_and_output_weights:
output_layer_weight = self._mgmodel.shared_embedding_or_output_weight()
else:
output_layer_weight = self._mgmodel.output_layer.weight
self.copy(output_layer_weight, self._hfmodel.lm_head.weight, param_type=ParamType.COLUMN)
def set_mla_selfattn_state(self, attn, hf_attn):
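        '''Set multi-latent attention (MLA) params.'''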
# NOTE: MLA qkv_bias always False
if self.args.q_lora_rank is None:
self.copy(attn.linear_q_proj.weight, hf_attn.q_proj.weight, param_type=ParamType.COLUMN)
else:
self.copy(attn.linear_q_down_proj.weight, hf_attn.q_a_proj.weight, param_type=ParamType.COLUMN)
self.copy(attn.linear_q_up_proj.weight, hf_attn.q_b_proj.weight, param_type=ParamType.COLUMN)
if self.args.qk_layernorm:
self.copy(
attn.linear_q_up_proj.layer_norm_weight,
hf_attn.q_a_layernorm.weight
)
self.copy(attn.linear_kv_down_proj.weight, hf_attn.kv_a_proj_with_mqa.weight, param_type=ParamType.COLUMN)
self.copy(attn.linear_kv_up_proj.weight, hf_attn.kv_b_proj.weight, param_type=ParamType.COLUMN)
if self.args.qk_layernorm:
self.copy(
attn.linear_kv_up_proj.layer_norm_weight,
hf_attn.kv_a_layernorm.weight
)
self.copy(
attn.linear_proj.weight,
hf_attn.o_proj.weight,
param_type=ParamType.ROW
)
def set_selfattn_state(self, attn, hf_attn):
'''Set self-attention params.'''
# Reshape loaded weights.
tp = self.tp_size
num_heads = self.args.num_attention_heads
num_query_groups = (self.args.num_query_groups if self.args.group_query_attention else self.args.num_attention_heads)
num_querys_per_group = num_heads // num_query_groups
dim = self.args.kv_channels
        assert num_heads % num_query_groups == 0
        # Copy QK layernorm weights if enabled.
if self.args.qk_layernorm:
self.copy(attn.q_layernorm.weight, hf_attn.q_norm.weight)
self.copy(attn.k_layernorm.weight, hf_attn.k_norm.weight)
        # Split Megatron's interleaved QKV weight into per-group q/k/v chunks for HF.
attn_proj_weight = attn.linear_qkv.weight.reshape(
(num_query_groups // tp, (2 + num_querys_per_group)*dim, -1)
)
(
q_proj_weight,
k_proj_weight,
v_proj_weight
) = torch.split(attn_proj_weight, [num_querys_per_group*dim, dim, dim], dim=1)
self.copy(q_proj_weight, hf_attn.q_proj.weight, param_type=ParamType.QKV_W)
self.copy(k_proj_weight, hf_attn.k_proj.weight, param_type=ParamType.QKV_W)
self.copy(v_proj_weight, hf_attn.v_proj.weight, param_type=ParamType.QKV_W)
self.copy(
attn.linear_proj.weight,
hf_attn.o_proj.weight,
param_type=ParamType.ROW
)
# Copy bias
if self.args.add_qkv_bias:
attn_proj_bias = attn.linear_qkv.bias.reshape(
(num_query_groups // tp, (2 + num_querys_per_group)*dim, -1)
)
q_proj_bias, k_proj_bias, v_proj_bias = torch.split(
attn_proj_bias,
[num_querys_per_group*dim, dim, dim],
dim=1
)
self.copy(q_proj_bias, hf_attn.q_proj.bias, param_type=ParamType.QKV_B)
self.copy(k_proj_bias, hf_attn.k_proj.bias, param_type=ParamType.QKV_B)
self.copy(v_proj_bias, hf_attn.v_proj.bias, param_type=ParamType.QKV_B)
def set_mlp_state(self, mlp, hf_mlp):
'''Set MLP params.'''
hidden_size = mlp.linear_fc1.weight.shape[-1]
gate_proj_weight, up_proj_weight = mlp.linear_fc1.weight.reshape(2, -1, hidden_size)
self.copy(
gate_proj_weight,
hf_mlp.gate_proj.weight,
param_type=ParamType.COLUMN
)
self.copy(
up_proj_weight,
hf_mlp.up_proj.weight,
param_type=ParamType.COLUMN
)
self.copy(
mlp.linear_fc2.weight,
hf_mlp.down_proj.weight,
param_type=ParamType.ROW
)
def set_sequential_mlp_state(self, experts, hf_experts):
        '''Set MoE expert params stored as sequential per-expert MLPs.'''
experts = experts.local_experts
for mg_expert_id, hf_expert_id in self._build_expert_parallel_mapping().items():
hidden_size = experts[mg_expert_id].linear_fc1.weight.shape[-1]
(
gate_proj_weight,
up_proj_weight
) = experts[mg_expert_id].linear_fc1.weight.reshape(2, -1, hidden_size)
self.copy(
gate_proj_weight,
hf_experts[hf_expert_id].gate_proj.weight,
param_type=ParamType.MOE_COLUMN
)
self.copy(
up_proj_weight,
hf_experts[hf_expert_id].up_proj.weight,
param_type=ParamType.MOE_COLUMN
)
self.copy(
experts[mg_expert_id].linear_fc2.weight,
hf_experts[hf_expert_id].down_proj.weight,
param_type=ParamType.MOE_ROW
)
def set_group_mlp_state(self, experts, hf_experts):
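        '''Set MoE expert params stored as TE grouped-GEMM weights (per-expert weight{i} attributes on linear_fc1/linear_fc2).'''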
for mg_expert_id, hf_expert_id in self._build_expert_parallel_mapping().items():
hidden_size = getattr(experts.linear_fc1, f'weight{mg_expert_id}').shape[-1]
(
gate_proj_weight,
up_proj_weight
) = getattr(experts.linear_fc1, f'weight{mg_expert_id}').reshape(2, -1, hidden_size)
self.copy(
gate_proj_weight,
hf_experts[hf_expert_id].gate_proj.weight,
param_type=ParamType.MOE_COLUMN
)
self.copy(
up_proj_weight,
hf_experts[hf_expert_id].up_proj.weight,
param_type=ParamType.MOE_COLUMN
)
self.copy(
getattr(experts.linear_fc2, f'weight{mg_expert_id}'),
hf_experts[hf_expert_id].down_proj.weight,
param_type=ParamType.MOE_ROW
)
def set_moe_layer_state(self, moe, hf_moe):
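        '''Set MoE layer params: router, routed experts (grouped or sequential) and optional shared experts.'''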
# router
self.copy(moe.router.weight, hf_moe.gate.weight)
if moe.router.enable_expert_bias:
self.copy(moe.router.expert_bias, hf_moe.gate.e_score_correction_bias)
# experts
if self.args.moe_grouped_gemm:
# group gemm
if self.args.moe_use_legacy_grouped_gemm:
                # Legacy GroupedGEMM keeps fused weight1/weight2 tensors; conversion is not implemented.
raise NotImplementedError("Currently only TE GroupGEMM is implemented.")
self.set_group_mlp_state(moe.experts, hf_moe.experts)
else:
# sequential
self.set_sequential_mlp_state(moe.experts, hf_moe.experts)
# shared experts
if moe.shared_experts is not None:
if moe.shared_experts.use_shared_expert_gate:
self.copy(moe.shared_experts.gate_weight, hf_moe.shared_expert_gate.weight)
hidden_size = moe.shared_experts.linear_fc1.weight.shape[-1]
gate_proj_weight, up_proj_weight = moe.shared_experts.linear_fc1.weight.reshape(2, -1, hidden_size)
self.copy(
gate_proj_weight,
hf_moe.shared_experts.gate_proj.weight,
param_type=ParamType.COLUMN
)
self.copy(
up_proj_weight,
hf_moe.shared_experts.up_proj.weight,
param_type=ParamType.COLUMN
)
self.copy(
moe.shared_experts.linear_fc2.weight,
hf_moe.shared_experts.down_proj.weight,
param_type=ParamType.ROW
)
def set_layer_state(self, layer, hf_layer):
'''Set transformer layer params.'''
if self.args.multi_latent_attention:
self.set_mla_selfattn_state(layer.self_attention, hf_layer.self_attn)
self.copy(layer.input_layernorm.weight, hf_layer.input_layernorm.weight)
else:
self.set_selfattn_state(layer.self_attention, hf_layer.self_attn)
self.copy(layer.self_attention.linear_qkv.layer_norm_weight, hf_layer.input_layernorm.weight)
if hasattr(layer.mlp, 'router'):
self.set_moe_layer_state(layer.mlp, hf_layer.mlp)
self.copy(layer.pre_mlp_layernorm.weight, hf_layer.post_attention_layernorm.weight)
else:
self.set_mlp_state(layer.mlp, hf_layer.mlp)
self.copy(layer.mlp.linear_fc1.layer_norm_weight, hf_layer.post_attention_layernorm.weight)
def check_and_save(self, output_dir):
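        '''Split the HF state dict into shard files, gather the required shards to the saver ranks via P2P, merge them per ParamType, and write the safetensors files plus the index.'''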
sharded_info = split_torch_state_dict_into_shards(
self._hfmodel.state_dict(),
max_shard_size=self.max_shard_size
)
global_shape = {self._hf_params_key_to_id[k]: v.shape for k, v in self._hfmodel.state_dict().items()}
# select local bucket(s) for each rank
n_savers = self.num_savers
local_buckets = []
max_n_local_buckets = 1
if not sharded_info.is_sharded and self.rank == 0:
local_buckets = list(sharded_info.filename_to_tensors.keys())
else:
n_buckets = len(sharded_info.filename_to_tensors)
rank = self.rank
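            # Evenly partition the shard files over the first n_savers ranks; the
            # first `remainder` ranks take one extra bucket, ranks >= n_savers save nothing.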
n_local_buckets = n_buckets // n_savers
remainder = n_buckets % n_savers
max_n_local_buckets = n_local_buckets + 1
if remainder == 0:
start = rank * n_local_buckets
max_n_local_buckets -= 1
elif rank < remainder:
n_local_buckets += 1
start = rank * n_local_buckets
else:
start = rank * n_local_buckets + remainder
local_buckets = list(sharded_info.filename_to_tensors.keys())[start:start + n_local_buckets]
if rank == 0:
index = {
"metadata": sharded_info.metadata,
"weight_map": sharded_info.tensor_to_filename,
}
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, "model.safetensors.index.json"), "w") as f:
f.write(json.dumps(index, indent=2))
self._collect_dist_info()
# In each iteration, all ranks save at most one local bucket
for bucket_idx in range(max_n_local_buckets):
required_keys = []
if bucket_idx < len(local_buckets):
bucket_name = local_buckets[bucket_idx]
required_keys: List[str] = sharded_info.filename_to_tensors[bucket_name]
# build send/recv op across all ranks
data, buffers, send_param_ids, recv_param_ids, ops = self._build_p2p_ops(required_keys)
# run data sync
if self.debug:
logging.info(f"[Iters {bucket_idx} RANK {self.rank}] starts synchronizing parameters with other ranks...")
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
if self.debug:
for op in ops:
if op.op == dist.isend:
logging.info(f"[Iters {bucket_idx} RANK {self.rank}] ({self.rank} -> {op.peer}) with {op.tensor.numel() * op.tensor.dtype.itemsize / 2 ** 20} MiB.")
else:
logging.info(f"[Iters {bucket_idx} RANK {self.rank}] ({op.peer} -> {self.rank}) with {op.tensor.numel() * op.tensor.dtype.itemsize / 2 ** 20} MiB.")
for req in reqs:
req.wait()
if self.debug:
logging.info(f"[Iters {bucket_idx} RANK {self.rank}] finishes synchronizing")
for remote_rank, param_ids in recv_param_ids.items():
for param_id, tensor in zip(param_ids, self._unpack_from_buffer(buffers[remote_rank], param_ids)):
data[param_id][remote_rank] = tensor
# apply merge function on the results
# data: Dict[param_id, Dict[rank_id, tensor]]
output_data = dict()
for param_id, data_dict in data.items():
key = self._id_to_hf_params_key[param_id]
param_type = ParamType(int(self._merge_type[param_id]))
if param_type == ParamType.NULL:
raise ValueError(f"ParamType.NULL found on {key}.")
try:
output_data[key] = self._merge_data(param_type, data_dict)
except ParamMergeError as e:
raise ValueError(f"Merge Error on key {key}: {e}")
if output_data[key].shape != global_shape[param_id]:
raise ValueError(f"Unexpected shape on {key}. Expected: {global_shape[param_id]}, but {output_data[key].shape}")
del data
# save safetensor files
if bucket_idx < len(local_buckets):
if not self.dryrun:
safe_save_file(
output_data,
os.path.join(output_dir, local_buckets[bucket_idx]),
metadata={"format": "pt"},
)
print(f"[Iters {bucket_idx} RANK {self.rank}] {local_buckets[bucket_idx]} is saved.")
if self.debug:
logging.debug(f"[Iters {bucket_idx} RANK {self.rank}] joined")
dist.barrier()
def _collect_dist_info(self):
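        '''Gather per-rank parameter metadata (owning ranks, shapes, dtypes, merge types) across the group.'''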
        # Collect the following metadata:
# param_id --> source_rank
self._has_param = torch.zeros(
[self.world_size, len(self._hf_params_key_to_id)], dtype=torch.bool, device=self.device
)
for param_id in self._local_params.keys():
self._has_param[self.rank][param_id] = True
dist.all_gather_into_tensor(self._has_param, self._has_param[self.rank])
self._has_param = self._has_param.T
# param_id --> tensor_shape Dict[int, Tuple[int, ...]]
# param_id --> tensor_dtype Dict[int, dtype]
for param_id, param in self._local_params.items():
self._tensor_shape[param_id] = param.shape
self._tensor_dtype[param_id] = param.dtype
# collect across ranks
tensor_shapes = [None] * self.world_size
tensor_dtypes = [None] * self.world_size
dist.all_gather_object(tensor_shapes, self._tensor_shape)
dist.all_gather_object(tensor_dtypes, self._tensor_dtype)
# NOTE: merge them together
for rank, remote_shape in enumerate(tensor_shapes):
if rank == self.rank:
continue
for remote_key, shape in remote_shape.items():
if remote_key not in self._tensor_shape:
self._tensor_shape[remote_key] = shape
elif shape != self._tensor_shape[remote_key]:
                    raise ValueError(
                        f"Found mismatched shape between local rank {self.rank} and remote rank {rank}, local shape: {self._tensor_shape[remote_key]}; remote shape: {shape}"
                    )
for rank, remote_dtype in enumerate(tensor_dtypes):
if rank == self.rank:
continue
for remote_key, dtype in remote_dtype.items():
if remote_key not in self._tensor_dtype:
self._tensor_dtype[remote_key] = dtype
elif dtype != self._tensor_dtype[remote_key]:
                    raise ValueError(
                        f"Found mismatched dtype between local rank {self.rank} and remote rank {rank}, local dtype: {self._tensor_dtype[remote_key]}; remote dtype: {dtype}"
                    )
# merge_type
global_merge_type = torch.zeros([self.world_size, self.hf_size], dtype=self._merge_type.dtype, device=self.device)
dist.all_gather_into_tensor(global_merge_type, self._merge_type)
for remote_rank_id, remote_merge_type in enumerate(global_merge_type):
if self.debug:
and_mask = torch.logical_and(remote_merge_type > 0, self._merge_type > 0)
if (self._merge_type[and_mask] != remote_merge_type[and_mask]).any():
param_id = -1
for param_id in range(self.hf_size):
if (
self._merge_type[param_id] > 0 and
remote_merge_type[param_id] > 0 and
self._merge_type[param_id] != remote_merge_type[param_id]
):
break
key = self._id_to_hf_params_key[param_id]
raise ValueError(f"Find mismatched merge_type between local rank {self.rank} and remote rank {remote_rank_id} on key {key}")
self._merge_type[remote_merge_type > 0] = remote_merge_type[remote_merge_type > 0]
def _unpack_from_buffer(self, buffer: torch.Tensor, param_ids: List[int]) -> List[torch.Tensor]:
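        '''Split a received uint8 buffer back into tensors using the globally collected shapes and dtypes.'''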
start = 0
datas = []
for param_id in param_ids:
shape = self._tensor_shape[param_id]
dtype = self._tensor_dtype[param_id]
offset = shape.numel() * dtype.itemsize
datas.append(buffer[start:start + offset].view(dtype).view(shape).clone())
start += offset
if start != buffer.numel():
raise ValueError(f"Expect {start} bytes from remote, but got {buffer.numel()} bytes!")
return datas
def _pack_into_byte_buffer(self, tensors: List[torch.Tensor]) -> torch.Tensor:
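        '''Flatten and concatenate tensors into one contiguous uint8 buffer for a single P2P send.'''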
return torch.cat([t.flatten().view(torch.uint8) for t in tensors])
def _build_p2p_ops(self, required_keys: List[str]):
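        '''Build batched isend/irecv ops so this rank receives every param shard required by required_keys; tensors sent to the same peer are packed into one uint8 buffer.'''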
required_ids = torch.zeros(
[self.world_size, self.hf_size], dtype=torch.bool, device=self.device
)
for k in required_keys:
required_ids[self.rank][self._hf_params_key_to_id[k]] = True
dist.all_gather_into_tensor(required_ids, required_ids[self.rank])
send_ops = []
if self.debug:
# (param_id, src_rank, dst_rank)
send_recv_pattern = torch.zeros([self.hf_size, self.world_size, self.world_size], dtype=torch.int, device=self.device)
send_param_ids = defaultdict(list)
datas = defaultdict(list)
# NOTE: for rank i to rank j, send params by ascending order
        # NOTE: to avoid hangs observed in multi-node scenarios, we merge all tensors sent to the same remote rank into one large uint8 tensor
for param_id, has_data in enumerate(self._has_param[:, self.rank]):
if not has_data:
continue
# group by remote_id
for remote_rank, should_send in enumerate(required_ids[:, param_id]):
if not should_send or remote_rank == self.rank:
continue
# NOTE: for each receiver, send param in ascending order by id
data = self._local_params[param_id]
if data.device != self.device:
logging.warning(f"Find unexpected device {data.device} on key {self._id_to_hf_params_key[param_id]}, moving to {self.device}")
data = data.to(self.device)
if data.dtype != self._tensor_dtype[param_id]:
raise ValueError(f"Get mismatched data type on key {self._id_to_hf_params_key[param_id]}")
datas[remote_rank].append(data)
send_param_ids[remote_rank].append(param_id)
if self.debug:
send_recv_pattern[param_id, self.rank, remote_rank] += 1
for remote_rank, raw_data in datas.items():
if len(raw_data) > 0:
send_ops.append(dist.P2POp(
dist.isend,
self._pack_into_byte_buffer(raw_data),
peer=remote_rank,
tag=self.rank * self.world_size + remote_rank # (sender_rank, receiver_rank) ignored in NCCL
))
recv_ops = []
collected_data = defaultdict(dict)
buffer_size = [0] * self.world_size
recv_param_ids = defaultdict(list)
# NOTE: for rank i to rank j, recv params by ascending order
for param_id, is_required in enumerate(required_ids[self.rank]):
if not is_required:
continue
for remote_rank, has_data in enumerate(self._has_param[param_id]):
if not has_data:
continue
if remote_rank == self.rank:
collected_data[param_id][remote_rank] = self._local_params[param_id]
else:
recv_param_ids[remote_rank].append(param_id)
shape = self._tensor_shape[param_id]
dtype = self._tensor_dtype[param_id]
buffer_size[remote_rank] += shape.numel() * dtype.itemsize
if self.debug:
send_recv_pattern[param_id, remote_rank, self.rank] -= 1
buffers = [None] * self.world_size
for remote_rank, rank_size in enumerate(buffer_size):
if rank_size == 0:
continue
buffers[remote_rank] = torch.empty(rank_size, dtype=torch.uint8, device=self.device)
recv_ops.append(dist.P2POp(
dist.irecv,
buffers[remote_rank],
peer=remote_rank,
tag=remote_rank * self.world_size + self.rank # (sender_rank, receiver_rank) ignored in NCCL
))
if self.debug:
dist.all_reduce(send_recv_pattern)
if send_recv_pattern.sum() != 0:
for param_id, pattern_per_param in enumerate(send_recv_pattern):
if pattern_per_param.sum() != 0:
raise ValueError(f"Mismatched send/recv ops detected on key {self._id_to_hf_params_key[param_id]}: {pattern_per_param}.")
raise ValueError("Mismatched send/recv ops detected.")
logging.debug(f"[RANK {self.rank}] {len(send_ops)} send op & {len(recv_ops)} recv op.")
return collected_data, buffers, send_param_ids, recv_param_ids, (send_ops + recv_ops)
def _merge_data(self, merge_type: ParamType, tensor_dict) -> torch.Tensor:
"""Merge ShardedTensor collected across the group on this rank.
If DP or EDP is larger than 1, tensor_dict could contains multiple data
copies and an data deduplication isrequired.
Args:
merge_type (ParamType): Themerge policy be applied on the
ShardedTensor
tensor_dict (Dict[int, torch.Tensor]): The collected dict of
ShardedTensors, mapping global_rank to tensor data.
Returns:
torch.Tensor: The merged tensor
"""
# tensor_dict: Dict[remote_rank, torch.Tensor]
def merge_along_axis(axis, tensor_dict, is_expert: bool = False):
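            '''Concatenate the deduplicated shards along `axis`, in ascending tp (etp for experts) rank order.'''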
global_ranks = torch.tensor(list(tensor_dict.keys()), dtype=torch.long, device=self.device) # (N, )
ranks = self._rank_mapping.index_select(0, global_ranks) # (N, 6)
if self.debug:
if is_expert and ranks[:, 5].any():
raise ParamMergeError("Unexpected expert parameter data from non-zero expert dp rank")
if not is_expert and ranks[:, 4].any():
raise ParamMergeError("Unexpected parameter data from non-zero dp rank")
if ranks[:, 1].max() != ranks[:, 1].min():
raise ParamMergeError("Unexpected parameter data from multiple pp ranks")
if is_expert and ranks[:, 3].max() != ranks[:, 3].min():
raise ParamMergeError("Unexpected expert parameter data from multiple ep ranks")
def deduplicate_and_sort(tensor_dict, rank_group: int):
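                '''Keep one shard per tp/etp rank (dropping dp/edp duplicates) and return them sorted by that rank.'''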
tensors = dict()
for global_rank, data in tensor_dict.items():
rank = self._rank_mapping[global_rank][rank_group]
tensors[int(rank)] = data
return [item[1] for item in sorted(tensors.items())]
            if is_expert:
                tensors = deduplicate_and_sort(tensor_dict, 2)  # dedup/sort by etp rank
            else:
                tensors = deduplicate_and_sort(tensor_dict, 0)  # dedup/sort by tp rank
return torch.cat(tensors, dim=axis)
def no_merge_func(tensor_dict):
return list(tensor_dict.values())[0]
def merge_qkv(is_bias, tensor_dict):
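            '''Merge per-group q/k/v shards along dim 0, then flatten back to the HF 2-D weight (or 1-D bias) layout.'''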
res = merge_along_axis(0, tensor_dict)
if is_bias:
assert res.shape[-1] == 1
return res.flatten()
return res.flatten(0, 1)
merge_func_mapping = {
ParamType.MOE_COLUMN: partial(merge_along_axis, 0, is_expert=True),
ParamType.MOE_ROW: partial(merge_along_axis, 1, is_expert=True),
ParamType.COLUMN: partial(merge_along_axis, 0),
ParamType.ROW: partial(merge_along_axis, 1),
ParamType.QKV_W: partial(merge_qkv, False),
ParamType.QKV_B: partial(merge_qkv, True),
ParamType.UNIQUE: no_merge_func,
}
return merge_func_mapping[merge_type](tensor_dict)