in chatlearn/synchronizer/megatron_vllm.py [0:0]
def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group):
"""
This function is applicable for synchronizing parameters from QWen with HEP enabled
to vLLM. In HEP, routed experts are split into a total number of EP size * TP size.
Thus, the function will all-to-all across EP size * TP size routed experts.
"""
if self.sync_map._to_alltoall_routed_experts_dict is None:
return params_to_sync, False
to_alltoall_routed_experts_dict = self.sync_map._to_alltoall_routed_experts_dict
layer_re = to_alltoall_routed_experts_dict["layer_re"]
to_regroup_modules_list = to_alltoall_routed_experts_dict["modules"]
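# Match the parameter name against the layer regex; group(2) holds the op name.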
m = layer_re.match(name)
if m is None:
return params_to_sync, False
op_name = m.group(2)
if op_name in to_regroup_modules_list:
tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
ep_size = self.src_module_args.args_dict["moe_expert_model_parallel_size"]
hep_size = tp_size * ep_size
moe_num_experts = self.src_module_args.args_dict["moe_num_experts"]
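# Each HEP rank holds moe_num_experts // (tp_size * ep_size) routed experts locally.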
local_num_experts = moe_num_experts // hep_size
hidden_size = self.src_module_args.args_dict["hidden_size"]
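# The fused w1/w3 weight and the w2 weight are regrouped differently below.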
if "dense_h_to_4h" in op_name:
# w13_weight
# regroup among different TP slices
param = params_to_sync.view((moe_num_experts, -1, hidden_size))
param = param.reshape((local_num_experts * 2, -1, hidden_size))
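# Split along dim 1 into hep_size slices, one per rank in the HEP group.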
params = list(param.chunk(hep_size, dim=1))
# reorder w1 and w3
params_list = []
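# Swap the two halves of each slice so that w1 and w3 end up in the expected order.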
while params:
param = params.pop(0)
param = param.reshape(param.shape[0] // 2, -1, hidden_size)
param_right, param_left = param.chunk(2, dim=1)
del param
param = torch.cat([param_left, param_right], dim=1)
del param_left
del param_right
params_list.append(param)
del params_to_sync
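# Allocate per-rank receive buffers and exchange the slices with an all-to-all.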
output = [
torch.empty(size=params_list[i].shape, dtype=params_list[i].dtype, device=params_list[i].device)
for i in range(hep_size)
]
torch.distributed.all_to_all(output, params_list, group=comm_group)
del params_list
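# Concatenate the slices received from all ranks along the expert dimension (dim 0).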
params_to_sync = torch.cat(output, dim=0).contiguous()
del output
else:
# w2_weight
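# View as (local_num_experts, -1, hidden_size) and split along dim 1 into one slice per HEP rank.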
param = params_to_sync.view((local_num_experts, -1, hidden_size))
params = list(param.chunk(hep_size, dim=1))
params_list = [ele.contiguous() for ele in params]
del param
del params
del params_to_sync
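# Allocate per-rank receive buffers and exchange the w2 slices with an all-to-all.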
output = [
torch.empty(size=params_list[i].shape, dtype=params_list[i].dtype, device=params_list[i].device)
for i in range(hep_size)
]
torch.distributed.all_to_all(output, params_list, group=comm_group)
del params_list
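# Concatenate the received slices along dim 0, then transpose the last two dims to match the target w2 layout.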
params_to_sync = torch.cat(output, dim=0).transpose(1, 2).contiguous()
del output
return params_to_sync, True
else:
return params_to_sync, False