in chatlearn/synchronizer/megatron_vllm.py [0:0]
def allgather_routed_experts_from_hep(self, name, params_to_sync, group_name, tp_rank):
"""
This function is applicable for synchronizing parameters from QWen with HEP enabled
to vLLM. In HEP, routed experts are split into a total number of EP size * TP size.
Thus, the function will all-gather across EP size * TP size routed experts and slice
them to TP size partitions.
"""
if self.sync_map._to_allgather_routed_experts_dict is None:
return params_to_sync, False
to_allgather_routed_experts_dict = self.sync_map._to_allgather_routed_experts_dict
layer_re = to_allgather_routed_experts_dict["layer_re"]
to_regroup_modules_list = to_allgather_routed_experts_dict["modules"]
m = layer_re.match(name)
if m is None:
return params_to_sync, False
op_name = m.group(2)
if op_name in to_regroup_modules_list:
tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
ep_size = self.src_module_args.args_dict["moe_expert_model_parallel_size"]
hep_size = tp_size * ep_size
moe_num_experts = self.src_module_args.args_dict["moe_num_experts"]
local_num_experts = moe_num_experts // hep_size
hidden_size = self.src_module_args.args_dict["hidden_size"]
output_tensor_list = [
torch.empty(size=params_to_sync.shape, dtype=params_to_sync.dtype, device=params_to_sync.device)
for _ in range(hep_size)
]
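        # All-gather the local routed-expert shard from every rank of the EP*TP group;
        # afterwards output_tensor_list[i] holds the shard owned by rank i
        # (`col` is the collective-communication backend imported at file scope).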
col.allgather(output_tensor_list, params_to_sync, group_name)
del params_to_sync
val_list = []
if "dense_h_to_4h" in op_name:
            # w13_weight: dense_h_to_4h stacks the w1 and w3 projections of each local expert
while output_tensor_list:
params = output_tensor_list.pop(0)
                # regroup across the different TP slices
params = params.view((moe_num_experts, -1, hidden_size)).contiguous()
params = params.reshape((local_num_experts * 2, -1, hidden_size))
params = params.chunk(tp_size, dim=1)[tp_rank]
                # reorder w1 and w3 by swapping the two halves along dim 1
params = params.reshape(params.shape[0] // 2, -1, hidden_size)
params_right, params_left = params.chunk(2, dim=1)
del params
params = torch.cat([params_left, params_right], dim=1)
del params_left
del params_right
val_list.append(params)
params_to_sync = torch.cat(val_list, dim=0).contiguous()
else:
            # w2_weight (the per-expert down projection)
while output_tensor_list:
params = output_tensor_list.pop(0)
params = params.reshape((local_num_experts, -1, hidden_size))
chunked_params = params.chunk(tp_size, dim=1)[tp_rank].contiguous()
del params
val_list.append(chunked_params)
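            # Concatenate all shards, then transpose into the layout vLLM uses for w2_weight.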
params_to_sync = torch.cat(val_list, dim=0).transpose(1, 2).contiguous()
del val_list
return params_to_sync, True
else:
return params_to_sync, False
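
For intuition, here is a minimal, self-contained sketch (not ChatLearn code; all sizes, names, and the `tp_rank` value are made up) that replays the w13 shape math above on random tensors: it fakes the all-gathered shards and checks that one TP rank ends up with every expert and 1/TP of the fused w1/w3 columns. The w2 path is analogous, with a final transpose into vLLM's w2 layout.

import torch

tp_size, ep_size = 2, 2                              # hypothetical parallel sizes
hep_size = tp_size * ep_size
moe_num_experts = 8
local_num_experts = moe_num_experts // hep_size
hidden_size, intermediate_size = 16, 32              # hypothetical model dims
tp_rank = 1

# Fake the result of the all-gather: one flat w13 shard per EP*TP rank.
per_shard_rows = moe_num_experts * 2 * intermediate_size // hep_size
shards = [torch.randn(per_shard_rows, hidden_size) for _ in range(hep_size)]

val_list = []
for params in shards:
    params = params.view(moe_num_experts, -1, hidden_size)
    params = params.reshape(local_num_experts * 2, -1, hidden_size)
    params = params.chunk(tp_size, dim=1)[tp_rank]   # keep this rank's TP slice
    params = params.reshape(params.shape[0] // 2, -1, hidden_size)
    first, second = params.chunk(2, dim=1)           # swap the w1/w3 halves
    val_list.append(torch.cat([second, first], dim=1))
w13 = torch.cat(val_list, dim=0).contiguous()

# Each TP rank holds all experts and 1/tp_size of the fused w1/w3 columns.
assert w13.shape == (moe_num_experts, 2 * intermediate_size // tp_size, hidden_size)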