def alltoall_routed_experts_from_hep()

in chatlearn/synchronizer/megatron_vllm.py


    def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group):
        """
        This function is applicable for synchronizing parameters from QWen with HEP enabled
        to vLLM. In HEP, routed experts are split into a total number of EP size * TP size.
        Thus, the function will all-to-all across EP size * TP size routed experts.
        """
        if self.sync_map._to_alltoall_routed_experts_dict is None:
            return params_to_sync, False

        to_alltoall_routed_experts_dict = self.sync_map._to_alltoall_routed_experts_dict
        layer_re = to_alltoall_routed_experts_dict["layer_re"]
        to_regroup_modules_list = to_alltoall_routed_experts_dict["modules"]

        m = layer_re.match(name)
        if m is None:
            return params_to_sync, False

        op_name = m.group(2)
        if op_name in to_regroup_modules_list:
            tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
            ep_size = self.src_module_args.args_dict["moe_expert_model_parallel_size"]
            hep_size = tp_size * ep_size
            moe_num_experts = self.src_module_args.args_dict["moe_num_experts"]

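            # Each HEP rank holds moe_num_experts // (tp_size * ep_size) routed experts.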
            local_num_experts = moe_num_experts // hep_size
            hidden_size = self.src_module_args.args_dict["hidden_size"]
            if "dense_h_to_4h" in op_name:
                # w13_weight
                # regroup among the different tp slices
                param = params_to_sync.view((moe_num_experts, -1, hidden_size))
                param = param.reshape((local_num_experts * 2, -1, hidden_size))
                params = list(param.chunk(hep_size, dim=1))
                # reorder w1 and w3
                params_list = []
                while params:
                    param = params.pop(0)
                    param = param.reshape(param.shape[0] // 2, -1, hidden_size)
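                    # split each slice in two along dim=1 and swap the halves to reorder w1 and w3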
                    param_right, param_left = param.chunk(2, dim=1)
                    del param
                    param = torch.cat([param_left, param_right], dim=1)
                    del param_left
                    del param_right
                    params_list.append(param)
                del params_to_sync
                output = [
                    torch.empty(size=params_list[i].shape, dtype=params_list[i].dtype, device=params_list[i].device)
                    for i in range(hep_size)
                ]
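                # exchange one slice with each of the hep_size ranks in the HEP comm group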
                torch.distributed.all_to_all(output, params_list, group=comm_group)
                del params_list
                params_to_sync = torch.cat(output, dim=0).contiguous()
                del output
            else:
                # w2_weight
                param = params_to_sync.view((local_num_experts, -1, hidden_size))
                params = list(param.chunk(hep_size, dim=1))
                params_list = [ele.contiguous() for ele in params]
                del param
                del params
                del params_to_sync
                output = [
                    torch.empty(size=params_list[i].shape, dtype=params_list[i].dtype, device=params_list[i].device)
                    for i in range(hep_size)
                ]
                torch.distributed.all_to_all(output, params_list, group=comm_group)
                del params_list
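                # concatenate the received slices and swap dims 1 and 2 to match the w2_weight layout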
                params_to_sync = torch.cat(output, dim=0).transpose(1, 2).contiguous()
                del output
            return params_to_sync, True
        else:
            return params_to_sync, False
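
The reshape-and-swap bookkeeping in the "dense_h_to_4h" (w13) branch is the least obvious part of the function. Below is a minimal, single-process sketch of just that regrouping step, using assumed toy sizes (moe_num_experts=8, tp_size=ep_size=2, hidden_size=16); the distributed all_to_all exchange itself is omitted:

    import torch

    # Assumed toy sizes for illustration only.
    tp_size, ep_size = 2, 2
    hep_size = tp_size * ep_size                       # 4 ranks in the HEP group
    moe_num_experts, hidden_size = 8, 16
    local_num_experts = moe_num_experts // hep_size    # 2 routed experts per rank

    # Stand-in for the flattened packed w1/w3 shard held by one rank.
    params_to_sync = torch.arange(moe_num_experts * 4 * hidden_size, dtype=torch.float32)

    param = params_to_sync.view((moe_num_experts, -1, hidden_size))
    param = param.reshape((local_num_experts * 2, -1, hidden_size))
    params = list(param.chunk(hep_size, dim=1))        # one slice per destination rank

    params_list = []
    for slice_ in params:
        slice_ = slice_.reshape(slice_.shape[0] // 2, -1, hidden_size)
        first_half, second_half = slice_.chunk(2, dim=1)
        # swap the two halves so w1/w3 arrive in the order the receiver expects
        params_list.append(torch.cat([second_half, first_half], dim=1))

    print([tuple(p.shape) for p in params_list])       # 4 slices of shape (2, 4, 16)

In the real method these slices are then exchanged with torch.distributed.all_to_all(output, params_list, group=comm_group), and the received slices are concatenated along dim=0 to rebuild the synchronized routed-expert weight.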