arctic_inference/vllm/args.py (95 lines of code) (raw):

# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import argparse
from dataclasses import dataclass, fields

from vllm.config import ParallelConfig
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.utils import FlexibleArgumentParser

from arctic_inference.patching import ArcticPatch
from arctic_inference.vllm.config import ArcticParallelConfig


@dataclass
class ArcticArgs:
    """Extra engine arguments introduced by Arctic Inference."""

    # Value behind --ulysses-sequence-parallel-size.
    ulysses_sequence_parallel_size: int = 1
    # Value behind --enable-shift-parallel.
    enable_shift_parallel: bool = False
    # Value behind --shift-parallel-threshold.
    shift_parallel_threshold: int = 512


@dataclass
class ArcticEngineArgs(EngineArgs, ArcticArgs):
    """``EngineArgs`` extended with the ``ArcticArgs`` fields."""


@dataclass
class ArcticAsyncEngineArgs(AsyncEngineArgs, ArcticArgs):
    """``AsyncEngineArgs`` extended with the ``ArcticArgs`` fields."""


# Patch for EngineArgs. Documented with comments instead of a docstring so
# that the class namespace gains no entries beyond the ones defined below.
class EngineArgsPatch(ArcticPatch[EngineArgs]):

    # References to the original implementations, captured at class-creation
    # time so the overrides below can delegate to them.
    _orig_post_init = EngineArgs.__post_init__
    _orig_add_cli_args = EngineArgs.add_cli_args
    # Take the plain function underneath the classmethod so it can be called
    # with an explicitly chosen class (see from_cli_args below).
    _orig_from_cli_args = EngineArgs.__dict__["from_cli_args"].__wrapped__
    _orig_create_engine_config = EngineArgs.create_engine_config
    _orig_is_v1_supported_oracle = EngineArgs._is_v1_supported_oracle

    def __new__(cls, *args, **kwargs):
        # Override __new__ to return an ArcticEngineArgs instead of an
        # EngineArgs when creating a new instance of the class.
        if cls is EngineArgs:
            return ArcticEngineArgs.__new__(ArcticEngineArgs, *args, **kwargs)
        return super(EngineArgs, cls).__new__(cls)

    def __post_init__(self):
        """Default the executor backend for Ulysses, then run the original."""
        # Explicitly set the distributed executor backend if ulysses is enabled
        # since the ulysses parameter is not passed to ParallelConfig.__init__,
        # which leads to the backend being defaulted incorrectly to "uni".
        if (self.ulysses_sequence_parallel_size > 1
                and self.distributed_executor_backend is None):
            self.distributed_executor_backend = "mp"
        self._orig_post_init()

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Register the original CLI arguments plus the Arctic Inference group.

        Args:
            parser: Parser to extend.

        Returns:
            The parser returned by the original ``add_cli_args``, with the
            "Arctic Inference" argument group added.
        """
        parser = EngineArgsPatch._orig_add_cli_args(parser)
        arctic_group = parser.add_argument_group(
            title="Arctic Inference",
            description="Arctic Inference configuration.",
        )
        arctic_group.add_argument(
            "--ulysses-sequence-parallel-size",
            type=int,
            default=ArcticEngineArgs.ulysses_sequence_parallel_size,
            help="Number of Ulysses sequence parallel replicas",
        )
        arctic_group.add_argument(
            "--enable-shift-parallel",
            action='store_true',
            help='If True, enable shift parallelism.')
        arctic_group.add_argument(
            "--shift-parallel-threshold",
            type=int,
            default=ArcticEngineArgs.shift_parallel_threshold,
            help=("Ulysses sequence parallel if batch size > threshold, "
                  "otherwise tensor parallel across the whole world size"),
        )
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build engine args from parsed CLI args, using the Arctic subclasses.

        Requests for ``EngineArgs`` / ``AsyncEngineArgs`` are redirected to
        ``ArcticEngineArgs`` / ``ArcticAsyncEngineArgs``; any other class is
        passed to the original implementation unchanged.
        """
        if cls is EngineArgs:
            return EngineArgsPatch._orig_from_cli_args(ArcticEngineArgs, args)
        if cls is AsyncEngineArgs:
            return EngineArgsPatch._orig_from_cli_args(ArcticAsyncEngineArgs,
                                                       args)
        return EngineArgsPatch._orig_from_cli_args(cls, args)

    def create_engine_config(self, *args, **kwargs):
        """Create the engine config and swap in an ``ArcticParallelConfig``.

        Positional and keyword arguments are forwarded unchanged to the
        original ``create_engine_config``.

        Returns:
            The config produced by the original implementation, with its
            ``parallel_config`` replaced by an ``ArcticParallelConfig`` that
            also carries the Arctic-specific parameters.
        """
        if (self.ulysses_sequence_parallel_size > 1
                and self.distributed_executor_backend is None):
            self.distributed_executor_backend = "mp"

        vllm_config = self._orig_create_engine_config(*args, **kwargs)

        # Recreate the parallel config with Arctic parameters since they might
        # not be passed to the parallel config __init__ when first initialized.
        # Only init=True fields are copied, because only those are accepted by
        # the dataclass constructor. A separate name is used so the caller's
        # **kwargs are not clobbered.
        parallel_kwargs = {
            f.name: getattr(vllm_config.parallel_config, f.name)
            for f in fields(vllm_config.parallel_config) if f.init
        }
        parallel_kwargs["ulysses_sequence_parallel_size"] = (
            self.ulysses_sequence_parallel_size)
        parallel_kwargs["enable_shift_parallel"] = self.enable_shift_parallel
        parallel_kwargs["shift_parallel_threshold"] = (
            self.shift_parallel_threshold)
        vllm_config.parallel_config = ArcticParallelConfig(**parallel_kwargs)

        return vllm_config

    def _is_v1_supported_oracle(self, *args, **kwargs):
        """Run the original oracle with Arctic speculative configs hidden.

        When the speculative config's "method" is "arctic" or "suffix", the
        config is temporarily set to ``None`` for the duration of the original
        call and restored afterwards.

        Returns:
            Whatever the original ``_is_v1_supported_oracle`` returns.
        """
        orig_speculative_config = self.speculative_config

        # Since Arctic Inference is only compatible with v1 and we already
        # check it earlier, we can just disable this check altogether.
        if (self.speculative_config is not None
                and self.speculative_config.get("method") in ("arctic",
                                                              "suffix")):
            self.speculative_config = None

        try:
            return self._orig_is_v1_supported_oracle(*args, **kwargs)
        finally:
            # Restore even if the oracle raises, so a failed check does not
            # leave the instance with its speculative config wiped.
            self.speculative_config = orig_speculative_config


# Patch for AsyncEngineArgs; mirrors the __new__ override on EngineArgsPatch.
class AsyncEngineArgsPatch(ArcticPatch[AsyncEngineArgs]):

    def __new__(cls, *args, **kwargs):
        # Override __new__ to return an ArcticAsyncEngineArgs instead of an
        # AsyncEngineArgs when creating a new instance of the class.
        if cls is AsyncEngineArgs:
            return ArcticAsyncEngineArgs.__new__(ArcticAsyncEngineArgs, *args,
                                                 **kwargs)
        return super(AsyncEngineArgs, cls).__new__(cls)