arctic_inference/vllm/config.py (85 lines of code) (raw):
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import logging
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from arctic_inference.patching import ArcticPatch
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclass
class ArcticParallelConfig(ParallelConfig):
    """ParallelConfig extended with sequence-parallel (SP) settings.

    Adds the Ulysses sequence-parallel degree and the shift-parallel options,
    and redefines ``world_size`` as PP * TP * SP instead of the PP * TP value
    assigned by ``ParallelConfig.__post_init__``.
    """

    # Sequence-parallel degree; it is the SP factor of world_size.
    ulysses_sequence_parallel_size: int = 1
    # Requires ulysses_sequence_parallel_size > 1 (checked in __post_init__).
    enable_shift_parallel: bool = False
    # NOTE(review): the unit of this threshold is not visible in this file;
    # confirm against the code that consumes it.
    shift_parallel_threshold: int = 512

    def __post_init__(self, *args, **kwargs):
        """Validate the shift-parallel settings, then run the base checks.

        Raises:
            ValueError: if ``enable_shift_parallel`` is set while
                ``ulysses_sequence_parallel_size`` is not greater than 1.
        """
        # Use "<= 1" rather than "== 1" so that zero and negative sizes are
        # rejected as well, which is what the error message promises.
        if (self.enable_shift_parallel
                and self.ulysses_sequence_parallel_size <= 1):
            raise ValueError("ulysses_sequence_parallel_size must be > 1 "
                             "when enable_shift_parallel is True.")
        super().__post_init__(*args, **kwargs)

    @property
    def world_size(self) -> int:
        """Return PP * TP * SP."""
        return (self.pipeline_parallel_size * self.tensor_parallel_size *
                self.ulysses_sequence_parallel_size)

    @world_size.setter
    def world_size(self, value: int) -> None:
        # ParallelConfig.__post_init__ will assign world_size to PP * TP, while
        # we want PP * TP * SP to be the world size. So we define world_size as
        # a property with a no-op setter to ignore the value later assigned by
        # ParallelConfig.__post_init__.
        pass
@dataclass
class ArcticSpeculativeConfig(SpeculativeConfig):
    # SpeculativeConfig extended with the suffix-decoding options below.

    # When True and `method` is None, SpeculativeConfigPatch.__post_init__
    # selects the "suffix" method; combined with method == "arctic" it copies
    # suffix_cache_max_depth into suffix_speculative_tokens.
    enable_suffix_decoding: bool = False
    # Copied into num_speculative_tokens when the "suffix" method is used.
    suffix_cache_max_depth: int = 64
    # Overwritten with suffix_cache_max_depth when method == "arctic" and
    # enable_suffix_decoding is set.
    suffix_speculative_tokens: int = 0
    # NOTE(review): the four settings below are not read anywhere in this
    # file; their exact semantics should be confirmed against their consumers.
    suffix_cache_max_requests: int = 100000
    suffix_max_spec_factor: float = 1.0
    suffix_max_spec_offset: float = 0.0
    suffix_min_token_prob: float = 0.1
class ParallelConfigPatch(ArcticPatch[ParallelConfig]):
    """Patch for ParallelConfig that redirects construction.

    NOTE(review): how ArcticPatch installs these attributes onto
    ParallelConfig is defined in arctic_inference.patching, not in this file.
    """

    def __new__(cls, *args, **kwargs):
        """Allocate the instance, swapping in ArcticParallelConfig.

        A direct ``ParallelConfig(...)`` construction yields an
        ArcticParallelConfig; any other class is allocated as itself.
        """
        if cls is not ParallelConfig:
            # Subclasses (ArcticParallelConfig included) are allocated through
            # the next __new__ after ParallelConfig in the MRO.
            return super(ParallelConfig, cls).__new__(cls)
        # Plain ParallelConfig was requested: hand back the Arctic subclass.
        return ArcticParallelConfig.__new__(ArcticParallelConfig, *args,
                                            **kwargs)
class SpeculativeConfigPatch(ArcticPatch[SpeculativeConfig]):
    """Patch for SpeculativeConfig adding suffix-decoding handling.

    NOTE(review): how ArcticPatch installs these attributes onto
    SpeculativeConfig is defined in arctic_inference.patching, not in this
    file.
    """

    # Originals captured at import time so the overrides below can delegate.
    _orig_from_dict = SpeculativeConfig.__dict__["from_dict"].__wrapped__
    _orig_post_init = SpeculativeConfig.__post_init__

    def __new__(cls, *args, **kwargs):
        """Allocate the instance, swapping in ArcticSpeculativeConfig.

        A direct ``SpeculativeConfig(...)`` construction yields an
        ArcticSpeculativeConfig; any other class is allocated as itself.
        """
        if cls is not SpeculativeConfig:
            return super(SpeculativeConfig, cls).__new__(cls)
        return ArcticSpeculativeConfig.__new__(ArcticSpeculativeConfig,
                                               *args, **kwargs)

    def __post_init__(self):
        """Resolve the suffix/arctic settings, then finish initialisation.

        The "suffix" method (explicit, or implied by enable_suffix_decoding
        when no method is given) is finalised here and only runs
        ``_verify_args``; every other method falls through to the original
        ``__post_init__``.
        """
        chosen = self.method
        suffix_mode = chosen == "suffix" or (chosen is None
                                             and self.enable_suffix_decoding)
        arctic_mode = chosen == "arctic"
        hybrid_mode = arctic_mode and self.enable_suffix_decoding

        if suffix_mode or arctic_mode:
            if self.disable_by_batch_size is None:
                logger.info("Defaulting disable_by_batch_size to 64")
                self.disable_by_batch_size = 64

        if hybrid_mode:
            self.suffix_speculative_tokens = self.suffix_cache_max_depth

        if not suffix_mode:
            self._orig_post_init()
            return

        # Pure suffix decoding: normalise the fields and validate directly.
        self.method = "suffix"
        self.enable_suffix_decoding = True
        self.num_speculative_tokens = self.suffix_cache_max_depth
        self._verify_args()

    @classmethod
    def from_dict(cls, dict_value: dict) -> SpeculativeConfig:
        """Build a speculative config from the parsed CLI dictionary.

        A call on SpeculativeConfig itself is redirected so that it builds an
        ArcticSpeculativeConfig; any other class is built as itself.
        """
        target = ArcticSpeculativeConfig if cls is SpeculativeConfig else cls
        return SpeculativeConfigPatch._orig_from_dict(target, dict_value)
class VllmConfigPatch(ArcticPatch[VllmConfig]):
    """Patch for VllmConfig that extends its string form.

    NOTE(review): how ArcticPatch installs these attributes onto VllmConfig
    is defined in arctic_inference.patching, not in this file.
    """

    # Original __str__, captured at import time so the override can extend it.
    _orig_str = VllmConfig.__str__

    def __str__(self, *args, **kwargs):
        """Return the original string followed by the Arctic parallel settings."""
        base = self._orig_str(*args, **kwargs)
        parallel = self.parallel_config
        extras = (
            f", ulysses_sequence_parallel_size={parallel.ulysses_sequence_parallel_size}",
            f", enable_shift_parallel={parallel.enable_shift_parallel}",
            f", shift_parallel_threshold={parallel.shift_parallel_threshold}",
        )
        return base + "".join(extras)
class MLPSpeculatorConfigPatch(ArcticPatch[MLPSpeculatorConfig]):
    """Patch for MLPSpeculatorConfig that accepts ``base_model_arch``.

    NOTE(review): how ArcticPatch installs these attributes onto
    MLPSpeculatorConfig is defined in arctic_inference.patching, not in this
    file.
    """

    # Original constructor, captured at import time so the override can
    # delegate to it.
    _orig_init = MLPSpeculatorConfig.__init__

    def __init__(self, *args, **kwargs):
        """Store ``base_model_arch`` (default ""), then run the original init.

        The keyword is removed from ``kwargs`` before delegating, so the
        original constructor never receives it.
        """
        arch = kwargs.pop("base_model_arch", "")
        self.base_model_arch = arch
        self._orig_init(*args, **kwargs)