def allgather_routed_experts_from_hep()

in chatlearn/synchronizer/megatron_vllm.py [0:0]


    def allgather_routed_experts_from_hep(self, name, params_to_sync, group_name, tp_rank):
        """
        This function is applicable for synchronizing parameters from QWen with HEP enabled
        to vLLM. In HEP, routed experts are split into a total number of EP size * TP size.
        Thus, the function will all-gather across EP size * TP size routed experts and slice
        them to TP size partitions.
        """
        if self.sync_map._to_allgather_routed_experts_dict is None:
            return params_to_sync, False

        to_allgather_routed_experts_dict = self.sync_map._to_allgather_routed_experts_dict
        layer_re = to_allgather_routed_experts_dict["layer_re"]
        to_regroup_modules_list = to_allgather_routed_experts_dict["modules"]

        m = layer_re.match(name)
        if m is None:
            return params_to_sync, False

        op_name = m.group(2)
        if op_name in to_regroup_modules_list:
            tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
            ep_size = self.src_module_args.args_dict["moe_expert_model_parallel_size"]
            hep_size = tp_size * ep_size
            moe_num_experts = self.src_module_args.args_dict["moe_num_experts"]
            local_num_experts = moe_num_experts // hep_size
            hidden_size = self.src_module_args.args_dict["hidden_size"]
            output_tensor_list = [
                torch.empty(size=params_to_sync.shape, dtype=params_to_sync.dtype, device=params_to_sync.device)
                for _ in range(hep_size)
            ]
            col.allgather(output_tensor_list, params_to_sync, group_name)
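            # output_tensor_list now holds all hep_size (TP size * EP size) shards of
            # this parameter, gathered via col (the collective backend, e.g. ray.util.collective).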
            del params_to_sync
            val_list = []
            if "dense_h_to_4h" in op_name:
                # w13_weight
                while output_tensor_list:
                    params = output_tensor_list.pop(0)
                    # regroup among different tp slices
                    params = params.view((moe_num_experts, -1, hidden_size)).contiguous()
                    params = params.reshape((local_num_experts * 2, -1, hidden_size))
                    params = params.chunk(tp_size, dim=1)[tp_rank]
                    # reorder w1 and w3: swap the two halves of each expert's rows
                    params = params.reshape(params.shape[0] // 2, -1, hidden_size)
                    params_right, params_left = params.chunk(2, dim=1)
                    del params
                    params = torch.cat([params_left, params_right], dim=1)
                    del params_left
                    del params_right
                    val_list.append(params)
                params_to_sync = torch.cat(val_list, dim=0).contiguous()
            else:
                # w2_weight
                while output_tensor_list:
                    params = output_tensor_list.pop(0)
                    params = params.reshape((local_num_experts, -1, hidden_size))
                    chunked_params = params.chunk(tp_size, dim=1)[tp_rank].contiguous()
                    del params
                    val_list.append(chunked_params)
                params_to_sync = torch.cat(val_list, dim=0).transpose(1, 2).contiguous()
            del val_list
            return params_to_sync, True
        else:
            return params_to_sync, False
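
Below is a minimal standalone sketch of the two regrouping patterns used above, run on toy tensors. The shapes and names (tp_size, ffn_slice, etc.) are illustrative assumptions rather than the real QWen/vLLM configuration; it only demonstrates the w1/w3 half-swap from the dense_h_to_4h branch and the TP slice plus transpose from the w2 branch.

    import torch

    # Toy configuration (illustrative only, not the real QWen sizes).
    tp_size, ep_size = 2, 2
    hep_size = tp_size * ep_size
    moe_num_experts, hidden_size, ffn_slice = 8, 4, 6
    local_num_experts = moe_num_experts // hep_size
    tp_rank = 0

    # --- w1/w3 half-swap (dense_h_to_4h branch) ---
    # After the per-shard reshapes, each expert's rows consist of two halves;
    # chunking along dim=1 and concatenating in reverse order swaps them.
    params = torch.arange(local_num_experts * 4 * hidden_size, dtype=torch.float32)
    params = params.reshape(local_num_experts, 4, hidden_size)
    right, left = params.chunk(2, dim=1)
    swapped = torch.cat([left, right], dim=1)
    assert torch.equal(swapped[:, :2], params[:, 2:])
    assert torch.equal(swapped[:, 2:], params[:, :2])

    # --- TP slicing + transpose (w2 branch) ---
    # Each gathered shard is viewed as (local_num_experts, -1, hidden_size),
    # sliced to this rank's TP partition, and finally transposed so the last
    # two dims become (hidden_size, intermediate_slice).
    shard = torch.randn(local_num_experts * ffn_slice, hidden_size)
    w2 = shard.reshape(local_num_experts, -1, hidden_size)
    w2 = w2.chunk(tp_size, dim=1)[tp_rank].contiguous()
    w2 = w2.transpose(1, 2).contiguous()
    print(w2.shape)  # torch.Size([2, 4, 3])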