arctic_inference/vllm/ulysses.py (401 lines of code) (raw):

# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import weakref
from contextlib import contextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Any

import torch
import vllm.distributed.parallel_state as parallel_state
import vllm.envs as envs
from vllm.attention.layer import Attention
from vllm.config import ModelConfig, ParallelConfig
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.distributed.parallel_state import (init_model_parallel_group,
                                             get_world_group,
                                             destroy_model_parallel,
                                             destroy_distributed_environment)
from vllm.executor.multiproc_worker_utils import (
    set_multiprocessing_worker_envs)
from vllm.utils import get_distributed_init_method, get_open_port
from vllm.v1.executor.abstract import FailureCallback
from vllm.v1.executor.multiproc_executor import (MultiprocExecutor, WorkerProc,
                                                 UnreadyWorkerProcHandle)
from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname
from vllm.compilation.backends import PiecewiseCompileInterpreter
from vllm.model_executor.layers.fused_moe import FusedMoE

from arctic_inference.patching import ArcticPatch


def apply_shift_parallel_patches():
    """Apply every Ulysses / shift-parallel patch defined in this module."""
    UlyssesModelConfigPatch.apply_patch()
    UlyssesParallelStatePatch.apply_patch()
    UlyssesWorkerProcPatch.apply_patch()
    UlyssesMultiprocExecutorPatch.apply_patch()
    UlyssesAttentionPatch.apply_patch()
    PiecewiseCompileInterpreterPatch.apply_patch()
    UlyssesFusedMoEPatch.apply_patch()


class UlyssesModelConfigPatch(ArcticPatch[ModelConfig]):
    """Patch ``ModelConfig`` so head counts and PP ranks account for SP."""

    # Keep handles to the unpatched implementations so the patched methods
    # (and other patches in this module) can still call them.
    _orig_get_num_kv_heads = ModelConfig.get_num_kv_heads
    _orig_get_num_attention_heads = ModelConfig.get_num_attention_heads

    def get_num_kv_heads(self: ModelConfig,
                         parallel_config: ParallelConfig) -> int:
        """Return the original kv-head count divided by SP size (at least 1)."""
        num_kv_heads = self._orig_get_num_kv_heads(parallel_config)
        sp_size = parallel_config.ulysses_sequence_parallel_size
        return max(1, num_kv_heads // sp_size)

    def get_num_attention_heads(self: ModelConfig,
                                parallel_config: ParallelConfig) -> int:
        """Return the original attention-head count divided by SP size (at least 1)."""
        num_heads = self._orig_get_num_attention_heads(parallel_config)
        sp_size = parallel_config.ulysses_sequence_parallel_size
        return max(1, num_heads // sp_size)

    def get_layers_start_end_indices(
            self, parallel_config: "ParallelConfig") -> tuple[int, int]:
        """Return the ``(start, end)`` layer indices owned by this PP rank.

        The PP rank is derived from the global rank assuming the layout
        ``DP x PP x SP x TP`` (TP fastest-varying).
        """
        from vllm.distributed.utils import get_pp_indices
        if (self.hf_text_config.model_type == "deepseek_mtp"
                or self.hf_config.model_type == "mimo_mtp"):
            total_num_hidden_layers = getattr(self.hf_text_config,
                                              "num_nextn_predict_layers", 0)
        else:
            total_num_hidden_layers = getattr(self.hf_text_config,
                                              "num_hidden_layers", 0)
        # the layout order is: DP x PP x SP x TP
        ranks_per_pp_stage = (
            parallel_config.tensor_parallel_size *
            parallel_config.ulysses_sequence_parallel_size)
        pp_rank = (parallel_config.rank //
                   ranks_per_pp_stage) % parallel_config.pipeline_parallel_size
        pp_size = parallel_config.pipeline_parallel_size
        start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
        return start, end


class UlyssesParallelStatePatch(ArcticPatch[parallel_state]):
    """Patch ``vllm.distributed.parallel_state`` with SP-aware group setup."""

    # Extra process groups added by this patch; populated by
    # initialize_model_parallel().
    _SP = None
    _SP_TP = None
    _SP_AA = None
    _SP_AG = None

    # Rationale for the SP_AA and SP_AG groups:
    # When the number of kv heads per TP rank is smaller than SP, the kv heads
    # have to be replicated across SP ranks, as TP does. To implement this,
    # the distributed kv heads are exchanged with a local all-to-all within
    # the SP_AA group followed by a local all-gather within the SP_AG group.
    # The SP_AA and SP_AG groups partition the SP group into two orthogonal
    # sub-groups and are only initialized when the kv-head count returned by
    # the unpatched ModelConfig.get_num_kv_heads is smaller than SP.
    # See the figure in PR #126 https://github.com/snowflakedb/ArcticInference/pull/126

    def initialize_model_parallel(
        tensor_model_parallel_size: int = 1,
        pipeline_model_parallel_size: int = 1,
        backend: Optional[str] = None,
    ) -> None:
        """
        Initialize model parallel groups.

        Arguments:
            tensor_model_parallel_size: number of GPUs used for tensor model
                parallelism.
            pipeline_model_parallel_size: number of GPUs used for pipeline model
                parallelism.
            backend: communication backend; when not given, the backend of the
                world group's device group is used.

        Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
        use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
        the model pipeline. The present function will
        create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
            4 tensor model-parallel groups:
                [g0, g1], [g2, g3], [g4, g5], [g6, g7]
            2 pipeline model-parallel groups:
                [g0, g2, g4, g6], [g1, g3, g5, g7]
        Note that for efficiency, the caller should make sure adjacent ranks
        are on the same DGX box. For example if we are using 2 DGX-1 boxes
        with a total of 16 GPUs, rank 0 to 7 belong to the first box and
        ranks 8 to 15 belong to the second box.

        In addition to TP/PP/DP/EP, this patched version builds the SP group,
        the SP_TP ("full-TP") group, and, when kv heads must be replicated,
        the SP_AA and SP_AG sub-groups, and publishes them on the
        ``parallel_state`` module.
        """
        from vllm.distributed.parallel_state import _DP, _EP, _PP, _TP

        # Get world size and rank. Ensure some consistencies.
        assert torch.distributed.is_initialized()
        world_size: int = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        backend = backend or torch.distributed.get_backend(
            get_world_group().device_group)

        data_parallel_size = 1
        # Default to "no sequence parallelism" so the name is always bound,
        # even if no vLLM config is available.
        sequence_parallel_size = 1
        from vllm.config import get_current_vllm_config
        config = get_current_vllm_config()
        if config is not None:
            data_parallel_size = config.parallel_config.data_parallel_size
            sequence_parallel_size = \
                config.parallel_config.ulysses_sequence_parallel_size

        # the layout order is: ExternalDP x DP x PP x SP x TP
        # ExternalDP is the data parallel group that is not part of the model,
        # every dp rank can generate independently (in verl integration).
        # DP is the data parallel group that is part of the model,
        # all the ranks in the same DP group should generate simultaneously,
        # i.e. the `generate` call in the same DP group should be called together,
        # otherwise it will cause deadlock.
        # to get group_ranks for each dimension, transpose that dimension to the
        # last dimension, then reshape to 2D, then unbind the last dimension
        all_ranks = torch.arange(world_size).reshape(
            -1, data_parallel_size, pipeline_model_parallel_size,
            sequence_parallel_size, tensor_model_parallel_size)  # noqa

        # Build the tensor model-parallel groups.
        assert _TP is None, ("tensor model parallel group is already initialized")
        group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
        TP_group_ranks = group_ranks

        # message queue broadcaster is only used in tensor model parallel group
        _TP = init_model_parallel_group(group_ranks,
                                        get_world_group().local_rank,
                                        backend,
                                        use_message_queue_broadcaster=True,
                                        group_name="tp")

        # Build the pipeline model-parallel groups.
        assert _PP is None, (
            "pipeline model parallel group is already initialized")
        group_ranks = all_ranks.transpose(2, 4).reshape(
            -1, pipeline_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
        PP_group_ranks = group_ranks
        _PP = init_model_parallel_group(group_ranks,
                                        get_world_group().local_rank,
                                        backend,
                                        group_name="pp")

        # Build the data parallel groups.
        assert _DP is None, ("data parallel group is already initialized")
        group_ranks = all_ranks.transpose(1, 4).reshape(
            -1, data_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
        DP_group_ranks = group_ranks
        _DP = init_model_parallel_group(group_ranks,
                                        get_world_group().local_rank,
                                        backend,
                                        group_name="dp")

        # Build the expert parallel groups.
        assert _EP is None, ("expert parallel group is already initialized")
        group_ranks = all_ranks.transpose(1, 3).reshape(
            -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
        EP_group_ranks = group_ranks
        _EP = init_model_parallel_group(group_ranks,
                                        get_world_group().local_rank,
                                        backend,
                                        group_name="ep")

        # Build the sequence parallel groups.
        assert parallel_state._SP is None, (
            "sequence parallel group is already initialized")
        group_ranks = all_ranks.transpose(3, 4).reshape(
            -1, sequence_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
        SP_group_ranks = group_ranks
        _SP = init_model_parallel_group(group_ranks,
                                        get_world_group().local_rank,
                                        backend,
                                        group_name="sp")

        # Build full-TP groups for ShiftParallel
        shift_parallel_size = (tensor_model_parallel_size *
                               sequence_parallel_size)
        assert parallel_state._SP_TP is None, (
            "full-TP group is already initialized")
        # transpose(3, 4) for obtaining the correct attn head order
        group_ranks = all_ranks.transpose(3, 4).reshape(
            -1, shift_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
        SP_TP_group_ranks = group_ranks
        _SP_TP = init_model_parallel_group(group_ranks,
                                           get_world_group().local_rank,
                                           backend,
                                           group_name="sp_tp")

        parallel_state.logger.info(
            "rank %s in world size %s is assigned as DP rank %s, PP rank %s, "
            "TP rank %s, EP rank %s, SP rank %s, SP_TP rank %s", rank,
            world_size, _DP.rank_in_group, _PP.rank_in_group,
            _TP.rank_in_group, _EP.rank_in_group, _SP.rank_in_group,
            _SP_TP.rank_in_group)

        # The groups above are function locals (the names were imported, not
        # declared global), so they must be written back to the module to be
        # visible to the rest of vLLM.
        parallel_state._TP = _TP
        parallel_state._PP = _PP
        parallel_state._SP = _SP
        parallel_state._SP_TP = _SP_TP
        parallel_state._DP = _DP
        # The EP group was created above but previously never published,
        # leaving parallel_state._EP as None.
        parallel_state._EP = _EP

        # check if SP requires kv replication
        kv_replicated = False
        if config is not None:
            num_kv_heads = config.model_config._orig_get_num_kv_heads(
                config.parallel_config)
            kv_replicated = num_kv_heads < sequence_parallel_size

        if kv_replicated:
            # divide SP group into two orthogonal sub-groups:
            sp_aa_size = num_kv_heads
            sp_ag_size = sequence_parallel_size // num_kv_heads
            all_ranks_ = torch.arange(world_size).reshape(
                -1, data_parallel_size, pipeline_model_parallel_size,
                sp_aa_size, sp_ag_size, tensor_model_parallel_size)

            group_ranks = all_ranks_.transpose(3, 5).reshape(
                -1, sp_aa_size).unbind(0)
            group_ranks = [x.tolist() for x in group_ranks]
            SP_AA_group_ranks = group_ranks
            # SP_AA group is used for all-to-all communication of kv heads
            _SP_AA = init_model_parallel_group(group_ranks,
                                               get_world_group().local_rank,
                                               backend,
                                               group_name="sp_aa")

            group_ranks = all_ranks_.transpose(4, 5).reshape(
                -1, sp_ag_size).unbind(0)
            group_ranks = [x.tolist() for x in group_ranks]
            SP_AG_group_ranks = group_ranks
            # SP_AG group is used for all-gather communication of kv heads
            _SP_AG = init_model_parallel_group(group_ranks,
                                               get_world_group().local_rank,
                                               backend,
                                               group_name="sp_ag")

            parallel_state._SP_AA = _SP_AA
            parallel_state._SP_AG = _SP_AG

        if get_world_group().local_rank == 0:
            parallel_state.logger.info(
                f"UlyssesParallelStatePatch initialized:\n"
                f" PP {_PP.world_size} ranks {PP_group_ranks}\n"
                f" TP {_TP.world_size} ranks {TP_group_ranks}\n"
                f" SP {_SP.world_size} ranks {SP_group_ranks}\n"
                f" DP {_DP.world_size} ranks {DP_group_ranks}\n"
                f" EP {_EP.world_size} ranks {EP_group_ranks}\n"
                f" SP_TP {_SP_TP.world_size} ranks {SP_TP_group_ranks}")
            if kv_replicated:
                parallel_state.logger.info(
                    f" SP_AA {parallel_state._SP_AA.world_size} ranks {SP_AA_group_ranks}\n"
                    f" SP_AG {parallel_state._SP_AG.world_size} ranks {SP_AG_group_ranks}\n")

    @contextmanager
    def graph_capture(device: torch.device):
        """
        `graph_capture` is a context manager which should surround the code that
        is capturing the CUDA graph. Its main purpose is to ensure that the
        some operations will be run after the graph is captured, before the graph
        is replayed. It returns a `GraphCaptureContext` object which contains the
        necessary data for the graph capture. Currently, it only contains the
        stream that the graph capture is running on. This stream is set to the
        current CUDA stream when the context manager is entered and reset to the
        default stream when the context manager is exited. This is to ensure that
        the graph capture is running on a separate stream from the default stream,
        in order to explicitly distinguish the kernels to capture
        from other kernels possibly launched on background in the default stream.

        This patched version enters the graph-capture context of the TP, PP
        and SP_TP groups.
        """
        from vllm.distributed.parallel_state import GraphCaptureContext
        context = GraphCaptureContext(torch.cuda.Stream(device=device))
        with parallel_state._TP.graph_capture(context), \
                parallel_state._PP.graph_capture(context), \
                parallel_state._SP_TP.graph_capture(context):
            yield context


class UlyssesWorkerProcPatch(ArcticPatch[WorkerProc]):
    """Patch ``WorkerProc`` so shutdown also tears down the Ulysses groups."""

    def destroy_model_parallel(self):
        """Destroy the SP, SP_TP, SP_AA and SP_AG groups and clear them.

        The groups are read from and reset on the ``parallel_state`` module
        itself. Importing the names and assigning ``None`` to them would only
        rebind function locals and leave the module attributes pointing at
        destroyed groups.
        """
        for group_name in ("_SP", "_SP_TP", "_SP_AA", "_SP_AG"):
            group = getattr(parallel_state, group_name, None)
            if group is not None:
                group.destroy()
                setattr(parallel_state, group_name, None)

    def shutdown(self):
        """Drop the message queues and destroy all distributed state."""
        self.rpc_broadcast_mq = None
        self.worker_response_mq = None
        # Module-level vLLM function (not the method above).
        destroy_model_parallel()
        # destroy Ulysses communicators here
        self.destroy_model_parallel()
        destroy_distributed_environment()


class UlyssesMultiprocExecutorPatch(ArcticPatch[MultiprocExecutor]):
    """Patch ``MultiprocExecutor`` so the world size accounts for SP."""

    def _init_executor(self) -> None:
        """Spawn one worker per rank and wait until all queues are ready.

        Asserts that ``world_size == TP x PP x SP`` before creating workers.
        If anything fails before start-up completes, the worker processes
        created so far are terminated.
        """
        # Call self.shutdown at exit to clean up
        # and ensure workers will be terminated.
        self._finalizer = weakref.finalize(self, self.shutdown)
        self.is_failed = False
        self.shutdown_event = threading.Event()
        self.failure_callback: Optional[FailureCallback] = None
        self.io_thread_pool: Optional[ThreadPoolExecutor] = None

        self.world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size
        pp_parallel_size = self.parallel_config.pipeline_parallel_size
        sp_parallel_size = self.parallel_config.ulysses_sequence_parallel_size
        assert (self.world_size == tensor_parallel_size * pp_parallel_size *
                sp_parallel_size), (
            f"world_size ({self.world_size}) must be equal to the "
            f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
            f"_parallel_size ({pp_parallel_size}) x ulysses_sequence_parallel"
            f"_size ({sp_parallel_size}).")

        # Set multiprocessing envs that are common to V0 and V1
        set_multiprocessing_worker_envs(self.parallel_config)

        # Multiprocessing-based executor does not support multi-node setting.
        # Since it only works for single node, we can use the loopback address
        # 127.0.0.1 for communication.
        distributed_init_method = get_distributed_init_method(
            "127.0.0.1", get_open_port())

        # Initialize worker and set up message queues for SchedulerOutputs
        # and ModelRunnerOutputs
        max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
        self.rpc_broadcast_mq = MessageQueue(self.world_size,
                                             self.world_size,
                                             max_chunk_bytes=max_chunk_bytes)
        scheduler_output_handle = self.rpc_broadcast_mq.export_handle()

        # Create workers
        unready_workers: list[UnreadyWorkerProcHandle] = []
        success = False
        try:
            for rank in range(self.world_size):
                unready_workers.append(
                    WorkerProc.make_worker_process(
                        vllm_config=self.vllm_config,
                        local_rank=rank,
                        rank=rank,
                        distributed_init_method=distributed_init_method,
                        input_shm_handle=scheduler_output_handle,
                    ))

            # Workers must be created before wait_for_ready to avoid
            # deadlock, since worker.init_device() does a device sync.
            self.workers = WorkerProc.wait_for_ready(unready_workers)

            # Ensure message queues are ready. Will deadlock if re-ordered
            # Must be kept consistent with the WorkerProc.
            self.rpc_broadcast_mq.wait_until_ready()
            for w in self.workers:
                w.worker_response_mq.wait_until_ready()

            self.start_worker_monitor()
            success = True
        finally:
            if not success:
                # Clean up the worker procs if there was a failure.
                self._ensure_worker_termination(
                    [w.proc for w in unready_workers])

        # For pipeline parallel, we use a thread pool for asynchronous
        # execute_model.
        if self.max_concurrent_batches > 1:
            # Note: must use only 1 IO thread to keep dequeue sequence
            # from the response queue
            self.io_thread_pool = ThreadPoolExecutor(
                max_workers=1, thread_name_prefix="mp_exec_io")

        self.output_rank = self._get_output_rank()


class UlyssesAttentionPatch(ArcticPatch[Attention]):
    """Patch ``Attention`` to exchange heads/sequence across the SP group."""

    _orig_init = Attention.__init__
    _orig_forward = Attention.forward

    def __init__(self, num_heads, *args, **kwargs):
        """Adjust the head counts for SP, then run the original ``__init__``.

        Outside shift-parallel mode, ``num_heads`` is divided by the SP size.
        ``num_kv_heads`` is divided as well, unless it is smaller than the SP
        size: then it is set to 1 and the SP_AA / SP_AG groups are recorded
        for the replicated-kv path in ``forward``.

        NOTE(review): ``num_kv_heads`` is read from ``kwargs`` and therefore
        has to be passed by keyword outside shift-parallel mode.
        """
        from .model_runner import is_shift_parallel_mode
        self.sp_size = parallel_state._SP.world_size
        self.sp_device_group = parallel_state._SP.device_group
        if not is_shift_parallel_mode():
            num_heads //= self.sp_size
            num_kv_heads = kwargs["num_kv_heads"]
            self.is_kv_replicated = num_kv_heads < self.sp_size
            if self.is_kv_replicated:
                num_kv_heads = 1
                assert parallel_state._SP_AA is not None and parallel_state._SP_AG is not None, (
                    "UlyssesAttentionPatch requires SP_AA and SP_AG groups to be initialized.")
                self.sp_aa_device_group = parallel_state._SP_AA.device_group
                self.sp_ag_device_group = parallel_state._SP_AG.device_group
                self.sp_aa_size = parallel_state._SP_AA.world_size
                self.sp_ag_size = parallel_state._SP_AG.world_size
                # this reorders the all-gathered sequence
                self.order = [
                    j * self.sp_aa_size + i
                    for i in range(self.sp_aa_size)
                    for j in range(self.sp_ag_size)
                ]
            else:
                num_kv_heads //= self.sp_size
            kwargs["num_kv_heads"] = num_kv_heads
        return self._orig_init(num_heads, *args, **kwargs)

    def forward(self, query, key, value, **kwargs):
        """Run attention with Ulysses all-to-all exchanges around it.

        When the SP size is 1 or shift-parallel mode is active, the original
        ``forward`` is called unchanged. Otherwise q/k/v are exchanged across
        the SP group before the original attention and the attention output
        is exchanged back afterwards.
        """
        from .model_runner import is_shift_parallel_mode
        if self.sp_size == 1 or is_shift_parallel_mode():
            return self._orig_forward(query, key, value, **kwargs)

        # Per-token widths of the local query and key/value projections.
        q_width = self.num_heads * self.head_size
        kv_width = self.num_kv_heads * self.head_size

        if self.is_kv_replicated:
            # Ulysses all-to-all 1/2 (query)
            q = (query.view(-1, self.sp_size, q_width)
                 .transpose(0, 1)
                 .reshape(-1, q_width))
            q_ = torch.empty_like(q)
            torch.distributed.all_to_all_single(q_, q,
                                                group=self.sp_device_group)
            # Ulysses pack (key, value)
            kv = (torch.cat(
                (key.view(-1, self.sp_aa_size, kv_width),
                 value.view(-1, self.sp_aa_size, kv_width)),
                dim=-1)
                .transpose(0, 1)
                .reshape(-1, 2 * kv_width))
            # Ulysses all-to-all (key, value)
            kv_part = torch.empty_like(kv)
            torch.distributed.all_to_all_single(
                kv_part, kv, group=self.sp_aa_device_group)
            # Ulysses all-gather (key, value)
            kv_ = torch.empty(q_.shape[0],
                              2 * kv_width,
                              dtype=query.dtype,
                              device=query.device)
            torch.distributed.all_gather_into_tensor(
                kv_, kv_part, group=self.sp_ag_device_group)
            # reorder
            kv_chunk = kv_.chunk(self.sp_size)
            kv_ordered = torch.cat([kv_chunk[i] for i in self.order])
            # unpack (key, value)
            k_, v_ = kv_ordered.split([kv_width] * 2, dim=-1)
        else:
            # pack
            qkv = (torch.cat(
                (query.view(-1, self.sp_size, q_width),
                 key.view(-1, self.sp_size, kv_width),
                 value.view(-1, self.sp_size, kv_width)),
                dim=-1)
                .transpose(0, 1)
                .reshape(-1, (self.num_heads + 2 * self.num_kv_heads) *
                         self.head_size))
            # Ulysses all-to-all 1/2
            qkv_ = torch.empty_like(qkv)
            torch.distributed.all_to_all_single(qkv_, qkv,
                                                group=self.sp_device_group)
            # unpack
            q_, k_, v_ = qkv_.split([q_width, kv_width, kv_width], dim=-1)

        # original attention
        c_ = self._orig_forward(q_, k_, v_, **kwargs)

        # Ulysses all-to-all 2/2
        c = torch.empty_like(c_)
        torch.distributed.all_to_all_single(c, c_, group=self.sp_device_group)
        output = (c.view(self.sp_size, -1, q_width)
                  .transpose(0, 1)
                  .reshape(-1, self.num_heads * self.sp_size * self.head_size))
        return output


class PiecewiseCompileInterpreterPatch(ArcticPatch[PiecewiseCompileInterpreter]):
    """Patch the piecewise compile interpreter to tolerate extra SymInt args."""

    # find the symbolic shape of the subgraph
    def find_symbolic_shape(
            self, args: tuple[torch.fx.node.Argument, ...]) -> torch.SymInt:
        """Return the single free symbol used by the FakeTensor shapes in ``args``.

        Asserts that exactly one distinct symbol appears.

        NOTE(review): the value returned is an element of
        ``dim.node.expr.free_symbols`` rather than a ``torch.SymInt`` —
        confirm the return annotation.
        """
        symbols = set()
        for x in args:
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor):
                for dim in x.shape:
                    if isinstance(dim, torch.SymInt):
                        symbols.update(dim.node.expr.free_symbols)
        assert len(symbols) == 1, (
            f"Expected exactly one symbolic shape, but found {len(symbols)}: {symbols}")
        return next(iter(symbols))

    def call_module(self, target: torch.fx.node.Target,
                    args: tuple[torch.fx.node.Argument, ...],
                    kwargs: dict[str, Any]) -> Any:
        """Run ``target`` and, if it is a compile submodule, install its backend.

        For a target listed in ``self.compile_submod_names``, the submodule is
        compiled for the general shape and replaced in ``self.module`` by a
        piecewise backend instance. The output of the underlying
        ``torch.fx.Interpreter.call_module`` call is returned either way.
        """
        assert isinstance(target, str)
        # [Arctic Inference]
        # Since monkeypatching inherits the original class
        # through ArcticPatch class, we lose the access to the original class'
        # super() function. Instead of using super(), we directly invoke call_module
        # from the super class torch.fx.Interpreter of PiecewiseCompileInterpreter.
        # see - v0.9.0.1/compilation/backends.py#L241
        output = torch.fx.Interpreter.call_module(self, target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            # [Arctic Inference]
            # Compiler may create subgraphs with certain symbolic
            # integer values that violates vllm's assumption here:
            # - v0.9.0.1/compilation/base_piecewise_backend.py#L64
            # The index of the significant symbol determines the runtime shape here:
            # - v0.9.0.1/compilation/cuda_piecewise_backend.py#L112
            # The fix is relaxing vllm's original assumption that there is only a
            # single symbolic that determines the shape. We then find the matching
            # symbol indices.
            sym_shape = self.find_symbolic_shape(args)
            sym_shape_indices = [
                i for i, x in enumerate(args)
                if isinstance(x, torch.SymInt) and sym_shape == x
            ]
            compiled_graph_for_general_shape = self.vllm_backend.\
                compiler_manager.compile(
                    submod,
                    args,
                    self.compilation_config.inductor_compile_config,
                    self.compilation_config,
                    graph_index=index,
                    num_graphs=len(self.compile_submod_names),
                    runtime_shape=None)

            piecewise_backend = resolve_obj_by_qualname(
                current_platform.get_piecewise_backend_cls())
            self.module.__dict__[target] = piecewise_backend(
                submod, self.vllm_config, self.graph_pool, index,
                len(self.compile_submod_names), sym_shape_indices,
                compiled_graph_for_general_shape, self.vllm_backend)

            from vllm.compilation.counter import compilation_counter
            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


class UlyssesFusedMoEPatch(ArcticPatch[FusedMoE]):
    """Patch ``FusedMoE.forward`` to call ``forward_impl`` directly."""

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor):
        """Return ``self.forward_impl(hidden_states, router_logits)``."""
        # directly call forward_impl to bypass custom opt
        # custom opt prevents using the shift model
        # we will expand this function to fuse SP with EP
        return self.forward_impl(hidden_states, router_logits)