# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import weakref
from contextlib import contextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Any

import torch
import vllm.distributed.parallel_state as parallel_state
import vllm.envs as envs
from vllm.attention.layer import Attention
from vllm.config import ModelConfig, ParallelConfig
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.distributed.parallel_state import (init_model_parallel_group,
                                             get_world_group,
                                             destroy_model_parallel,
                                             destroy_distributed_environment)
from vllm.executor.multiproc_worker_utils import (
    set_multiprocessing_worker_envs)
from vllm.utils import get_distributed_init_method, get_open_port
from vllm.v1.executor.abstract import FailureCallback
from vllm.v1.executor.multiproc_executor import (MultiprocExecutor, WorkerProc,
                                                 UnreadyWorkerProcHandle)
from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname
from vllm.compilation.backends import PiecewiseCompileInterpreter
from vllm.model_executor.layers.fused_moe import FusedMoE

from arctic_inference.patching import ArcticPatch


def apply_shift_parallel_patches():
    """Install all ShiftParallel / Ulysses monkeypatches.

    The patches are applied one after another in the order listed below.
    """
    patch_classes = (
        UlyssesModelConfigPatch,
        UlyssesParallelStatePatch,
        UlyssesWorkerProcPatch,
        UlyssesMultiprocExecutorPatch,
        UlyssesAttentionPatch,
        PiecewiseCompileInterpreterPatch,
        UlyssesFusedMoEPatch,
    )
    for patch_cls in patch_classes:
        patch_cls.apply_patch()


class UlyssesModelConfigPatch(ArcticPatch[ModelConfig]):

    # Makes ModelConfig report head counts divided by the Ulysses
    # sequence-parallel size, and compute pipeline layer ranges from the
    # DP x PP x SP x TP rank layout.

    _orig_get_num_kv_heads = ModelConfig.get_num_kv_heads
    _orig_get_num_attention_heads = ModelConfig.get_num_attention_heads

    def get_num_kv_heads(self: ModelConfig,
                         parallel_config: ParallelConfig) -> int:
        """Return the KV-head count after Ulysses sharding.

        The value from the original ``get_num_kv_heads`` is floor-divided by
        ``ulysses_sequence_parallel_size`` and clamped to at least 1.
        """
        base_kv_heads = self._orig_get_num_kv_heads(parallel_config)
        return max(
            base_kv_heads // parallel_config.ulysses_sequence_parallel_size, 1)

    def get_num_attention_heads(self: ModelConfig,
                                parallel_config: ParallelConfig) -> int:
        """Return the query-head count after Ulysses sharding.

        The value from the original ``get_num_attention_heads`` is
        floor-divided by ``ulysses_sequence_parallel_size`` and clamped to at
        least 1.
        """
        base_heads = self._orig_get_num_attention_heads(parallel_config)
        return max(
            base_heads // parallel_config.ulysses_sequence_parallel_size, 1)

    def get_layers_start_end_indices(
            self, parallel_config: "ParallelConfig") -> tuple[int, int]:
        """Return the ``(start, end)`` layer range owned by this PP stage.

        MTP draft models (``deepseek_mtp`` / ``mimo_mtp``) count
        ``num_nextn_predict_layers``; all other models count
        ``num_hidden_layers`` (missing attributes count as 0).
        """
        from vllm.distributed.utils import get_pp_indices

        uses_mtp_layers = (self.hf_text_config.model_type == "deepseek_mtp"
                           or self.hf_config.model_type == "mimo_mtp")
        layer_count_attr = ("num_nextn_predict_layers"
                            if uses_mtp_layers else "num_hidden_layers")
        total_num_hidden_layers = getattr(self.hf_text_config,
                                          layer_count_attr, 0)

        # Ranks are laid out as DP x PP x SP x TP with TP varying fastest, so
        # every PP stage spans a contiguous block of TP * SP ranks.
        ranks_per_pp_stage = (parallel_config.tensor_parallel_size *
                              parallel_config.ulysses_sequence_parallel_size)
        pp_size = parallel_config.pipeline_parallel_size
        pp_rank = (parallel_config.rank // ranks_per_pp_stage) % pp_size

        start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
        return start, end


class UlyssesParallelStatePatch(ArcticPatch[parallel_state]):

    # Process-group handles this patch adds to vllm's parallel_state module
    # (all start out as None):
    #   _SP    - Ulysses sequence-parallel group
    #   _SP_TP - combined SP x TP ("full-TP") group used by ShiftParallel
    #   _SP_AA - SP sub-group for the kv-head all-to-all (kv replication only)
    #   _SP_AG - SP sub-group for the kv-head all-gather (kv replication only)
    _SP = None
    _SP_TP = None
    _SP_AA = None
    _SP_AG = None
    # Rationale for SP_AA and SP_AG groups:
    # When there are fewer kv heads than SP ranks, the kv heads cannot simply
    # be split across SP; they are distributed and replicated as in TP.
    # To implement the logic, the distributed kv heads are exchanged with a
    # local all-to-all within the SP_AA group followed by a local all-gather
    # within the SP_AG group. SP_AA (size num_kv_heads) and SP_AG (size
    # SP // num_kv_heads) partition the SP group into two orthogonal
    # sub-groups. They are initialized ONLY when num_kv_heads < SP, where
    # num_kv_heads is the value returned by the original
    # ModelConfig.get_num_kv_heads.
    # See the figure in PR #126 https://github.com/snowflakedb/ArcticInference/pull/126

    def initialize_model_parallel(
        tensor_model_parallel_size: int = 1,
        pipeline_model_parallel_size: int = 1,
        backend: Optional[str] = None,
    ) -> None:
        """
        Initialize model parallel groups, including the Ulysses sequence
        parallel (SP) groups.

        Arguments:
            tensor_model_parallel_size: number of GPUs used for tensor model
                parallelism.
            pipeline_model_parallel_size: number of GPUs used for pipeline model
                parallelism.
            backend: torch.distributed backend for the new groups; defaults to
                the backend of the world group's device group.

        The data-parallel and Ulysses sequence-parallel sizes are read from the
        current vLLM config. Ranks are laid out as
        ExternalDP x DP x PP x SP x TP, with TP varying fastest. The TP, PP,
        DP, EP, SP and SP_TP groups are always created and stored on
        ``vllm.distributed.parallel_state``. SP_AA / SP_AG are created only
        when kv heads must be replicated across SP.

        Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
        use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
        the model pipeline (with SP = 1). The present function will
        create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
            4 tensor model-parallel groups:
                [g0, g1], [g2, g3], [g4, g5], [g6, g7]
            2 pipeline model-parallel groups:
                [g0, g2, g4, g6], [g1, g3, g5, g7]
        Note that for efficiency, the caller should make sure adjacent ranks
        are on the same DGX box. For example if we are using 2 DGX-1 boxes
        with a total of 16 GPUs, rank 0 to 7 belong to the first box and
        ranks 8 to 15 belong to the second box.
        """
        from vllm.config import get_current_vllm_config
        from vllm.distributed.parallel_state import _DP, _EP, _PP, _TP

        # Get world size and rank. Ensure some consistencies.
        assert torch.distributed.is_initialized()
        world_size: int = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        backend = backend or torch.distributed.get_backend(
            get_world_group().device_group)
        local_rank = get_world_group().local_rank

        # The config is dereferenced unconditionally below, so read both
        # sizes from it directly.
        config = get_current_vllm_config()
        data_parallel_size = config.parallel_config.data_parallel_size
        sequence_parallel_size = \
            config.parallel_config.ulysses_sequence_parallel_size

        def _group_ranks(layout: torch.Tensor,
                         group_size: int) -> list[list[int]]:
            # One group per row of the 2-D reshape; the grouped dimension(s)
            # must already be the innermost one(s) of ``layout``.
            return [
                x.tolist() for x in layout.reshape(-1, group_size).unbind(0)
            ]

        # the layout order is: ExternalDP x DP x PP x SP x TP
        # ExternalDP is the data parallel group that is not part of the model,
        # every dp rank can generate independently (in verl integration).
        # DP is the data parallel group that is part of the model,
        # all the ranks in the same DP group should generate simultaneously,
        # i.e. the `generate` call in the same DP group should be called together,
        # otherwise it will cause deadlock.
        # to get group_ranks for each dimension, transpose that dimension to the
        # last dimension, then reshape to 2D, then unbind the last dimension
        all_ranks = torch.arange(world_size).reshape(
            -1, data_parallel_size, pipeline_model_parallel_size,
            sequence_parallel_size, tensor_model_parallel_size)  # noqa

        # Group creation is collective: every rank must create the groups in
        # the same order, so keep this sequence fixed.

        # Build the tensor model-parallel groups.
        assert _TP is None, ("tensor model parallel group is already initialized")
        TP_group_ranks = _group_ranks(all_ranks, tensor_model_parallel_size)
        # message queue broadcaster is only used in tensor model parallel group
        _TP = init_model_parallel_group(TP_group_ranks,
                                        local_rank,
                                        backend,
                                        use_message_queue_broadcaster=True,
                                        group_name="tp")

        # Build the pipeline model-parallel groups.
        assert _PP is None, (
            "pipeline model parallel group is already initialized")
        PP_group_ranks = _group_ranks(all_ranks.transpose(2, 4),
                                      pipeline_model_parallel_size)
        _PP = init_model_parallel_group(PP_group_ranks,
                                        local_rank,
                                        backend,
                                        group_name="pp")

        # Build the data-parallel groups.
        assert _DP is None, ("data parallel group is already initialized")
        DP_group_ranks = _group_ranks(all_ranks.transpose(1, 4),
                                      data_parallel_size)
        _DP = init_model_parallel_group(DP_group_ranks,
                                        local_rank,
                                        backend,
                                        group_name="dp")

        # Build the expert-parallel groups (DP x TP, with SP moved outward).
        assert _EP is None, ("expert parallel group is already initialized")
        EP_group_ranks = _group_ranks(
            all_ranks.transpose(1, 3),
            data_parallel_size * tensor_model_parallel_size)
        _EP = init_model_parallel_group(EP_group_ranks,
                                        local_rank,
                                        backend,
                                        group_name="ep")

        # Build the sequence parallel groups.
        assert parallel_state._SP is None, (
            "sequence parallel group is already initialized")
        SP_group_ranks = _group_ranks(all_ranks.transpose(3, 4),
                                      sequence_parallel_size)
        _SP = init_model_parallel_group(SP_group_ranks,
                                        local_rank,
                                        backend,
                                        group_name="sp")

        # Build full-TP groups for ShiftParallel
        shift_parallel_size = (tensor_model_parallel_size *
                               sequence_parallel_size)
        assert parallel_state._SP_TP is None, (
            "full-TP group is already initialized")
        # transpose(3, 4) for obtaining the correct attn head order
        SP_TP_group_ranks = _group_ranks(all_ranks.transpose(3, 4),
                                         shift_parallel_size)
        _SP_TP = init_model_parallel_group(SP_TP_group_ranks,
                                           local_rank,
                                           backend,
                                           group_name="sp_tp")

        parallel_state.logger.info(
            "rank %s in world size %s is assigned as DP rank %s, PP rank %s, "
            "TP rank %s, EP rank %s, SP rank %s, SP_TP rank %s", rank,
            world_size, _DP.rank_in_group, _PP.rank_in_group,
            _TP.rank_in_group, _EP.rank_in_group, _SP.rank_in_group,
            _SP_TP.rank_in_group)

        parallel_state._TP = _TP
        parallel_state._PP = _PP
        parallel_state._SP = _SP
        parallel_state._SP_TP = _SP_TP
        parallel_state._DP = _DP
        # The EP group was created above (a collective call); it must also be
        # published, otherwise get_ep_group() cannot find it and vLLM's
        # destroy_model_parallel never releases it.
        parallel_state._EP = _EP

        # check if SP requires kv replication
        num_kv_heads = config.model_config._orig_get_num_kv_heads(
            config.parallel_config)
        kv_replicated = num_kv_heads < sequence_parallel_size
        if kv_replicated:
            # divide SP group into two orthogonal sub-groups:
            # SP = SP_AA (num_kv_heads) x SP_AG (SP // num_kv_heads)
            sp_aa_size = num_kv_heads
            sp_ag_size = sequence_parallel_size // num_kv_heads
            all_ranks_ = torch.arange(world_size).reshape(
                -1, data_parallel_size, pipeline_model_parallel_size,
                sp_aa_size, sp_ag_size, tensor_model_parallel_size)

            # SP_AA group is used for all-to-all communication of kv heads
            SP_AA_group_ranks = _group_ranks(all_ranks_.transpose(3, 5),
                                             sp_aa_size)
            _SP_AA = init_model_parallel_group(SP_AA_group_ranks,
                                               local_rank,
                                               backend,
                                               group_name="sp_aa")

            # SP_AG group is used for all-gather communication of kv heads
            SP_AG_group_ranks = _group_ranks(all_ranks_.transpose(4, 5),
                                             sp_ag_size)
            _SP_AG = init_model_parallel_group(SP_AG_group_ranks,
                                               local_rank,
                                               backend,
                                               group_name="sp_ag")

            parallel_state._SP_AA = _SP_AA
            parallel_state._SP_AG = _SP_AG

        if local_rank == 0:
            parallel_state.logger.info(
                    f"UlyssesParallelStatePatch initialized:\n"
                    f"  PP {_PP.world_size} ranks {PP_group_ranks}\n"
                    f"  TP {_TP.world_size} ranks {TP_group_ranks}\n"
                    f"  SP {_SP.world_size} ranks {SP_group_ranks}\n"
                    f"  DP {_DP.world_size} ranks {DP_group_ranks}\n"
                    f"  EP {_EP.world_size} ranks {EP_group_ranks}\n"
                    f"  SP_TP {_SP_TP.world_size} ranks {SP_TP_group_ranks}")
            if kv_replicated:
                parallel_state.logger.info(
                    f"  SP_AA {parallel_state._SP_AA.world_size} ranks {SP_AA_group_ranks}\n"
                    f"  SP_AG {parallel_state._SP_AG.world_size} ranks {SP_AG_group_ranks}\n")

    @contextmanager
    def graph_capture(device: torch.device):
        """
        `graph_capture` is a context manager which should surround the code that
        is capturing the CUDA graph. Its main purpose is to ensure that the
        some operations will be run after the graph is captured, before the graph
        is replayed. It returns a `GraphCaptureContext` object which contains the
        necessary data for the graph capture. Currently, it only contains the
        stream that the graph capture is running on. This stream is set to the
        current CUDA stream when the context manager is entered and reset to the
        default stream when the context manager is exited. This is to ensure that
        the graph capture is running on a separate stream from the default stream,
        in order to explicitly distinguish the kernels to capture
        from other kernels possibly launched on background in the default stream.

        Besides the TP and PP groups, the ShiftParallel SP_TP group also
        enters graph-capture mode.
        """
        from vllm.distributed.parallel_state import GraphCaptureContext
        context = GraphCaptureContext(torch.cuda.Stream(device=device))
        with parallel_state._TP.graph_capture(context), \
                parallel_state._PP.graph_capture(context), \
                parallel_state._SP_TP.graph_capture(context):
            yield context


class UlyssesWorkerProcPatch(ArcticPatch[WorkerProc]):

    def destroy_model_parallel(self):
        """Destroy the Ulysses process groups and clear their module globals.

        Handles ``_SP``, ``_SP_TP``, ``_SP_AA`` and ``_SP_AG`` on
        ``vllm.distributed.parallel_state``. Groups that were never created
        (``None`` or not present on the module) are skipped. Every handle is
        reset to ``None`` afterwards, so the groups can be initialized again.
        """
        # The globals must be cleared on the module object itself. Rebinding a
        # name obtained via ``from parallel_state import _SP`` only changes a
        # local variable and would leave the destroyed group reachable (and
        # trip the "already initialized" assertions on re-initialization).
        for group_name in ("_SP", "_SP_TP", "_SP_AA", "_SP_AG"):
            group = getattr(parallel_state, group_name, None)
            if group is not None:
                group.destroy()
            setattr(parallel_state, group_name, None)

    def shutdown(self):
        """Drop the worker's message queues and tear down all process groups.

        vLLM's own model-parallel groups are destroyed first, then the
        Ulysses groups, and finally the distributed environment.
        """
        self.rpc_broadcast_mq = None
        self.worker_response_mq = None
        destroy_model_parallel()
        # destroy Ulysses communicators here
        self.destroy_model_parallel()
        destroy_distributed_environment()


class UlyssesMultiprocExecutorPatch(ArcticPatch[MultiprocExecutor]):

    # Replaces MultiprocExecutor._init_executor so that the world-size check
    # includes the Ulysses sequence-parallel dimension.

    def _init_executor(self) -> None:
        """Spawn one worker process per rank and set up executor state.

        Creates the broadcast MessageQueue (from ``shm_broadcast``) for
        scheduler outputs, starts ``world_size`` worker processes, waits for
        the workers and all message queues to become ready, and starts the
        worker monitor. If start-up fails, the worker processes created so far
        are terminated. An IO thread pool is created only when
        ``max_concurrent_batches > 1``.

        Raises:
            AssertionError: if ``parallel_config.world_size`` differs from
                tensor_parallel_size * pipeline_parallel_size *
                ulysses_sequence_parallel_size.
        """
        # Call self.shutdown at exit to clean up
        # and ensure workers will be terminated.
        self._finalizer = weakref.finalize(self, self.shutdown)
        self.is_failed = False
        self.shutdown_event = threading.Event()
        self.failure_callback: Optional[FailureCallback] = None
        self.io_thread_pool: Optional[ThreadPoolExecutor] = None

        self.world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size
        pp_parallel_size = self.parallel_config.pipeline_parallel_size
        sp_parallel_size = self.parallel_config.ulysses_sequence_parallel_size
        # Ulysses SP is an extra parallel dimension, so the world size checked
        # here is the product TP x PP x SP.
        assert (self.world_size ==
                tensor_parallel_size * pp_parallel_size * sp_parallel_size), (
            f"world_size ({self.world_size}) must be equal to the "
            f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
            f"_parallel_size ({pp_parallel_size}) x ulysses_sequence_parallel"
            f"_size ({sp_parallel_size}).")

        # Set multiprocessing envs that are common to V0 and V1
        set_multiprocessing_worker_envs(self.parallel_config)

        # Multiprocessing-based executor does not support multi-node setting.
        # Since it only works for single node, we can use the loopback address
        # 127.0.0.1 for communication.
        distributed_init_method = get_distributed_init_method(
            "127.0.0.1", get_open_port())

        # Initialize worker and set up message queues for SchedulerOutputs
        # and ModelRunnerOutputs
        # VLLM_MQ_MAX_CHUNK_BYTES_MB is given in MiB; convert to bytes.
        max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
        self.rpc_broadcast_mq = MessageQueue(self.world_size,
                                             self.world_size,
                                             max_chunk_bytes=max_chunk_bytes)
        scheduler_output_handle = self.rpc_broadcast_mq.export_handle()

        # Create workers
        unready_workers: list[UnreadyWorkerProcHandle] = []
        success = False
        try:
            # Single-node executor: each worker's local rank equals its
            # global rank.
            for rank in range(self.world_size):
                unready_workers.append(
                    WorkerProc.make_worker_process(
                        vllm_config=self.vllm_config,
                        local_rank=rank,
                        rank=rank,
                        distributed_init_method=distributed_init_method,
                        input_shm_handle=scheduler_output_handle,
                    ))

            # Workers must be created before wait_for_ready to avoid
            # deadlock, since worker.init_device() does a device sync.
            self.workers = WorkerProc.wait_for_ready(unready_workers)

            # Ensure message queues are ready. Will deadlock if re-ordered
            # Must be kept consistent with the WorkerProc.
            self.rpc_broadcast_mq.wait_until_ready()
            for w in self.workers:
                w.worker_response_mq.wait_until_ready()

            self.start_worker_monitor()
            success = True
        finally:
            if not success:
                # Clean up the worker procs if there was a failure.
                self._ensure_worker_termination(
                    [w.proc for w in unready_workers])

        # For pipeline parallel, we use a thread pool for asynchronous
        # execute_model.
        if self.max_concurrent_batches > 1:
            # Note: must use only 1 IO thread to keep dequeue sequence
            # from the response queue
            self.io_thread_pool = ThreadPoolExecutor(
                max_workers=1, thread_name_prefix="mp_exec_io")

        # NOTE(review): presumably the rank whose response is returned to the
        # caller; confirm against MultiprocExecutor._get_output_rank.
        self.output_rank = self._get_output_rank()


class UlyssesAttentionPatch(ArcticPatch[Attention]):

    # Ulysses sequence parallelism for vLLM Attention. Outside shift-parallel
    # mode each SP rank keeps 1/SP of the attention heads. The token-sharded
    # q/k/v are exchanged with all-to-all before the original attention runs,
    # and the attention output is exchanged back afterwards.

    _orig_init = Attention.__init__
    _orig_forward = Attention.forward

    def __init__(self, num_heads, *args, **kwargs):
        """Shard the head counts across the Ulysses SP group, then run the
        original ``Attention.__init__``.

        Args:
            num_heads: query-head count handed to this layer before Ulysses
                sharding.
            *args, **kwargs: forwarded to the original constructor.
                ``num_kv_heads`` is read from ``kwargs`` (keyword argument);
                a missing or ``None`` value is treated as ``num_heads``.

        In shift-parallel mode the head counts are passed through unchanged.
        Otherwise both counts are divided by the SP size. When there are fewer
        kv heads than SP ranks, each rank keeps a single kv head, and the
        SP_AA / SP_AG sub-groups are recorded for exchanging kv in ``forward``.
        """
        from .model_runner import is_shift_parallel_mode
        self.sp_size = parallel_state._SP.world_size
        self.sp_device_group = parallel_state._SP.device_group
        # Always defined, so the attribute exists regardless of the mode this
        # layer was built in.
        self.is_kv_replicated = False
        if not is_shift_parallel_mode():
            num_kv_heads = kwargs.get("num_kv_heads")
            if num_kv_heads is None:
                # vLLM's Attention treats an omitted/None num_kv_heads as
                # num_heads; resolve it here, before num_heads is sharded,
                # instead of failing with KeyError / TypeError.
                num_kv_heads = num_heads
            num_heads //= self.sp_size
            self.is_kv_replicated = num_kv_heads < self.sp_size
            if self.is_kv_replicated:
                num_kv_heads = 1
                assert (parallel_state._SP_AA is not None
                        and parallel_state._SP_AG is not None), (
                    "UlyssesAttentionPatch requires SP_AA and SP_AG groups to be initialized.")
                self.sp_aa_device_group = parallel_state._SP_AA.device_group
                self.sp_ag_device_group = parallel_state._SP_AG.device_group
                self.sp_aa_size = parallel_state._SP_AA.world_size
                self.sp_ag_size = parallel_state._SP_AG.world_size
                # this reorders the all-gathered sequence
                self.order = [j * self.sp_aa_size + i
                              for i in range(self.sp_aa_size)
                              for j in range(self.sp_ag_size)]
            else:
                num_kv_heads //= self.sp_size
            kwargs["num_kv_heads"] = num_kv_heads
        return self._orig_init(num_heads, *args, **kwargs)

    def forward(self, query, key, value, **kwargs):
        """Run attention with Ulysses all-to-all exchanges around it.

        Falls through to the original forward when the SP size is 1 or in
        shift-parallel mode. Otherwise:

        * q/k/v are exchanged so that this rank holds its head shard for the
          tokens of every SP rank;
        * the original attention runs on that shard;
        * the output is exchanged back to this rank's own tokens, with all
          ``num_heads * sp_size`` heads.
        """
        from .model_runner import is_shift_parallel_mode
        if self.sp_size == 1 or is_shift_parallel_mode():
            return self._orig_forward(query, key, value, **kwargs)

        q_width = self.num_heads * self.head_size
        kv_width = self.num_kv_heads * self.head_size

        if self.is_kv_replicated:
            # Ulysses all-to-all 1/2 (query)
            q = (query.view(-1, self.sp_size, q_width)
                 .transpose(0, 1)
                 .reshape(-1, q_width))
            q_ = torch.empty_like(q)
            torch.distributed.all_to_all_single(q_, q, group=self.sp_device_group)
            # Ulysses pack (key, value): viewed as
            # (tokens, sp_aa_size, num_kv_heads * head_size)
            kv = (torch.cat((key.view(-1, self.sp_aa_size, kv_width),
                             value.view(-1, self.sp_aa_size, kv_width)),
                            dim=-1)
                  .transpose(0, 1)
                  .reshape(-1, 2 * kv_width))
            # Ulysses all-to-all (key, value) within SP_AA
            kv_part = torch.empty_like(kv)
            torch.distributed.all_to_all_single(kv_part, kv, group=self.sp_aa_device_group)
            # Ulysses all-gather (key, value) within SP_AG; the result has as
            # many rows as the exchanged query.
            kv_ = torch.empty(q_.shape[0],
                              2 * kv_width,
                              dtype=query.dtype,
                              device=query.device)
            torch.distributed.all_gather_into_tensor(kv_,
                                                     kv_part,
                                                     group=self.sp_ag_device_group)
            # reorder the all-gathered chunks (see self.order in __init__)
            kv_chunk = kv_.chunk(self.sp_size)
            kv_ordered = torch.cat([kv_chunk[i] for i in self.order])
            # unpack (key, value)
            k_, v_ = kv_ordered.split([kv_width] * 2, dim=-1)
        else:
            # pack q/k/v so a single all-to-all moves all three
            qkv = (torch.cat(
                (query.view(-1, self.sp_size, q_width),
                 key.view(-1, self.sp_size, kv_width),
                 value.view(-1, self.sp_size, kv_width)),
                dim=-1)
                .transpose(0, 1)
                .reshape(-1, (self.num_heads + 2 * self.num_kv_heads) * self.head_size))
            # Ulysses all-to-all 1/2
            qkv_ = torch.empty_like(qkv)
            torch.distributed.all_to_all_single(qkv_, qkv, group=self.sp_device_group)
            # unpack
            q_, k_, v_ = qkv_.split([q_width, kv_width, kv_width], dim=-1)

        # original attention
        c_ = self._orig_forward(q_, k_, v_, **kwargs)

        # Ulysses all-to-all 2/2
        c = torch.empty_like(c_)
        torch.distributed.all_to_all_single(c, c_, group=self.sp_device_group)
        output = (c.view(self.sp_size, -1, q_width)
                  .transpose(0, 1)
                  .reshape(-1, self.num_heads * self.sp_size * self.head_size))

        return output


class PiecewiseCompileInterpreterPatch(ArcticPatch[PiecewiseCompileInterpreter]):

    def find_symbolic_shape(self, args: tuple[torch.fx.node.Argument,
                                              ...]) -> Any:
        """Return the single free symbol used by the fake-tensor shapes.

        Collects ``node.expr.free_symbols`` from every symbolic dimension of
        every ``FakeTensor`` in ``args``. The result is the symbol itself (an
        element of that set), not a ``torch.SymInt``.

        Raises:
            AssertionError: if the tensors do not share exactly one symbol.
        """
        symbols = set()
        for x in args:
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor):
                for dim in x.shape:
                    if isinstance(dim, torch.SymInt):
                        symbols.update(dim.node.expr.free_symbols)
        assert len(symbols) == 1, (
            f"Expected exactly one symbolic shape, but found {len(symbols)}: {symbols}")
        return next(iter(symbols))

    def call_module(self, target: torch.fx.node.Target,
                    args: tuple[torch.fx.node.Argument,
                                ...], kwargs: dict[str, Any]) -> Any:
        """Execute submodule ``target``; if it is a compiled piece, replace it
        in ``self.module`` with a platform piecewise backend.

        Returns the submodule's output from the plain interpreter run.
        """
        assert isinstance(target, str)
        # [Arctic Inference]
        # Since monkeypatching inherits the original class
        # through ArcticPatch class, we lose the access to the original class'
        # super() function. Instead of using super(), we directly invoke call_module
        # from the super class torch.fx.Interpreter of PiecewiseCompileInterpreter.
        # see - v0.9.0.1/compilation/backends.py#L241
        output = torch.fx.Interpreter.call_module(self, target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            # [Arctic Inference]
            # Compiler may create subgraphs with certain symbolic
            # integer values that violate vllm's assumption here:
            # - v0.9.0.1/compilation/base_piecewise_backend.py#L64
            # The index of the significant symbol determines the runtime shape here:
            # - v0.9.0.1/compilation/cuda_piecewise_backend.py#L112
            # The fix relaxes vllm's original assumption that there is only a
            # single symbolic int argument. We find the symbol that determines
            # the tensor shapes, then keep only the SymInt arguments equal to it.
            sym_shape = self.find_symbolic_shape(args)
            # Compare the SymInt's sympy expression directly (the same
            # accessor find_symbolic_shape uses) rather than relying on
            # sympy <-> SymInt reflected equality.
            sym_shape_indices = [
                i for i, x in enumerate(args)
                if isinstance(x, torch.SymInt) and x.node.expr == sym_shape
            ]

            compiled_graph_for_general_shape = self.vllm_backend.\
                compiler_manager.compile(
                submod,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=index,
                num_graphs=len(self.compile_submod_names),
                runtime_shape=None)

            piecewise_backend = resolve_obj_by_qualname(
                current_platform.get_piecewise_backend_cls())
            self.module.__dict__[target] = piecewise_backend(
                submod, self.vllm_config, self.graph_pool, index,
                len(self.compile_submod_names), sym_shape_indices,
                compiled_graph_for_general_shape, self.vllm_backend)

            from vllm.compilation.counter import compilation_counter
            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


class UlyssesFusedMoEPatch(ArcticPatch[FusedMoE]):

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor):
        """Run the MoE layer through ``forward_impl`` directly.

        Going through the registered custom op prevents using the shift model,
        so the op is bypassed here. Planned: extend this to fuse SP with EP.
        """
        moe_impl = self.forward_impl
        return moe_impl(hidden_states, router_logits)
