arctic_inference/vllm/model_runner.py (634 lines of code) (raw):
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import copy
import time
from typing import Any, Union, Optional, TYPE_CHECKING
from itertools import tee
import numpy as np
import torch
import vllm.distributed.parallel_state as parallel_state
import vllm.envs as envs
from tqdm import tqdm
from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter
from vllm.config import CompilationLevel
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.parallel_state import (get_pp_group, get_tp_group,
is_global_first_rank)
from vllm.forward_context import set_forward_context
from vllm.config import VllmConfig
from vllm.model_executor.model_loader import get_model
from vllm.sequence import IntermediateTensors
from vllm.utils import round_up
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import MAX_SPEC_LEN, RejectionSampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_model_runner import GPUModelRunner, logger
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from arctic_inference.suffix_decoding import (SuffixDecodingCache,
SuffixDecodingDraft)
from arctic_inference.patching import ArcticPatch
from arctic_inference.vllm.spec_dec.arctic_proposer import ArcticProposer
# Shift-parallel mode flag, switched by ``set_shift_parallel_mode``:
# None when no mode has been set, True while the shift (SP_TP) tensor-parallel
# group is active, False while the original TP group is explicitly selected.
SP_TP_MODE: Optional[bool] = None
@contextlib.contextmanager
def set_shift_parallel_mode(mode: Optional[bool]):
    """Temporarily select which tensor-parallel group ``parallel_state`` uses.

    Args:
        mode: ``True`` installs ``parallel_state._SP_TP`` as the active TP
            group, ``False`` installs the original group saved in
            ``parallel_state._ORIG_TP``, and ``None`` leaves everything
            untouched (the context is a no-op).

    On exit (normal or exceptional) the previous ``SP_TP_MODE`` value and the
    previously active TP group are restored, so contexts can be nested.
    """
    global SP_TP_MODE

    if mode is None:
        # Nothing to switch.
        yield
        return

    # Outside shift mode the active group is the original TP group, so record
    # it; a later ``mode=False`` switch re-selects it.
    if not is_shift_parallel_mode():
        # NOTE(review): _TP_STATE_PATCHED is defined outside this file;
        # presumably it guards against conflicting TP patching - confirm.
        assert not parallel_state._TP_STATE_PATCHED
        parallel_state._ORIG_TP = parallel_state._TP

    saved_mode = SP_TP_MODE
    saved_group = parallel_state.get_tp_group()

    SP_TP_MODE = mode
    if mode:
        parallel_state._TP = parallel_state._SP_TP
    else:
        parallel_state._TP = parallel_state._ORIG_TP

    try:
        yield
    finally:
        # Undo the switch regardless of how the body exited.
        SP_TP_MODE = saved_mode
        parallel_state._TP = saved_group
def is_shift_parallel_mode() -> bool:
    """Report whether shift-parallel (SP_TP) mode is currently active.

    Only the exact value ``True`` of the module-level ``SP_TP_MODE`` flag
    counts; ``None``, ``False`` or any other value yields ``False``.
    """
    # Reading a module global needs no ``global`` declaration.
    return SP_TP_MODE is True
class GPUModelRunnerPatch(ArcticPatch[GPUModelRunner]):
    """Arctic Inference extensions to vLLM's ``GPUModelRunner``.

    Adds Ulysses sequence parallelism, an optional "shift" model (a second
    copy of the model loaded with tensor parallelism multiplied by the
    Ulysses size) that is used for small batches, SwiftKV logits-index
    plumbing, and the "arctic" / "mlp_speculator" / "suffix" speculative
    decoding methods. Wrapped original methods are kept as ``_orig_*``.
    """

    _orig_initialize_kv_cache = GPUModelRunner.initialize_kv_cache
    _orig_prepare_inputs = GPUModelRunner._prepare_inputs
    _orig_profile_run = GPUModelRunner.profile_run
    _orig_load_model = GPUModelRunner.load_model
    _orig_propose_draft_token_ids = GPUModelRunner.propose_draft_token_ids
    _orig_init = GPUModelRunner.__init__

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Initialize the runner and set up Arctic speculative decoding.

        Speculative configs using the "arctic", "suffix" or "mlp_speculator"
        methods are hidden from the original constructor and restored
        afterwards, so that the drafter is created here instead.

        Raises:
            ValueError: if Ulysses sequence parallelism is combined with
                native sequence parallelism, if an unknown Arctic speculative
                method is configured, or if suffix decoding is enabled for a
                method other than "arctic", "mlp_speculator" or "suffix".
        """
        # Ulysses sequence parallelism
        if vllm_config.parallel_config.ulysses_sequence_parallel_size > 1:
            self.use_ulysses = True
            pass_config = vllm_config.compilation_config.pass_config
            if pass_config.enable_sequence_parallelism:
                raise ValueError(
                    "Ulysses sequence parallelism is incompatible with native "
                    "sequence parallelism. Set enable_sequence_parallelism "
                    "to False in the pass config to use Ulysses.")
        else:
            self.use_ulysses = False

        # Speculative decoding
        # TODO: Use "arctic" as an umbrella method that also covers the Arctic
        # Inference version of "mlp_speculator".
        if (vllm_config.speculative_config is not None
                and vllm_config.speculative_config.method
                in ("arctic", "suffix", "mlp_speculator")):
            # Delay the creation of the drafter until
            # after the child class has been initialized.
            arctic_speculative_config = vllm_config.speculative_config
            vllm_config.speculative_config = None
        else:
            arctic_speculative_config = None

        self._orig_init(vllm_config, device)

        # Set up speculative decoding.
        self._suffix_cache = None
        if arctic_speculative_config is not None:
            # Restore the speculative config.
            self.vllm_config.speculative_config = arctic_speculative_config
            self.speculative_config = arctic_speculative_config

            if get_pp_group().is_last_rank:
                if self.speculative_config.method in ("arctic",
                                                      "mlp_speculator"):
                    self.drafter = ArcticProposer(self.vllm_config)
                elif self.speculative_config.method != "suffix":
                    raise ValueError("Unknown speculative decoding method: "
                                     f"{self.speculative_config.method}")
                self.rejection_sampler = RejectionSampler()

        if (self.speculative_config is not None
                and self.speculative_config.enable_suffix_decoding):
            if self.speculative_config.method not in ("arctic", "suffix",
                                                      "mlp_speculator"):
                raise ValueError(
                    "Suffix decoding is only supported with the 'arctic', "
                    "'mlp_speculator' or 'suffix' spec decoding methods.")
            spec_cfg = self.speculative_config
            self._suffix_cache = SuffixDecodingCache(
                max_tree_depth=spec_cfg.suffix_cache_max_depth,
                max_cached_requests=spec_cfg.suffix_cache_max_requests)

    def profile_run(self) -> None:
        """Run the original profiling pass, then profile the shift model.

        If a shift model is loaded it is temporarily swapped in as
        ``self.model`` and run once in shift-parallel mode so that it is
        compiled too; ``self.model`` is restored even if that run fails.
        """
        self._orig_profile_run()
        if self.shift_model is not None:
            # Run the shift model to trigger compilation.
            orig_model, self.model = self.model, self.shift_model
            try:
                with set_shift_parallel_mode(True):
                    self._dummy_run(self.max_num_tokens, is_profile=True)
            finally:
                self.model = orig_model

    def _prepare_inputs(self, *args, **kwargs):
        """Call the original ``_prepare_inputs`` and tag SwiftKV metadata.

        Returns the same values as the original. As a side effect, each
        attention-metadata object gets a ``swiftkv_logits_indices`` attribute.
        """
        attn_metadata, attention_cuda_graphs, logits_indices, *rest = (
            self._orig_prepare_inputs(*args, **kwargs))
        # SwiftKV requires knowing the logits indices from inside the model
        # definition in order to early-stop the prefill tokens.
        for meta in attn_metadata.values():
            meta.swiftkv_logits_indices = logits_indices
        return attn_metadata, attention_cuda_graphs, logits_indices, *rest

    def monkeypatch_forward(self: GPUModelRunner):
        """Wrap ``self.model.forward`` for Ulysses sequence parallelism.

        Each SP rank runs the model (in non-shift mode) on its contiguous
        ``N // sp_size`` slice of the inputs and positions. Unless the model
        already returned all ``N`` rows (SwiftKV), the per-rank outputs are
        all-gathered over the SP device group. The token count ``N`` is
        assumed to be a multiple of the SP world size; ``execute_model`` pads
        it accordingly on the Ulysses path.
        """
        sp_size = parallel_state._SP.world_size
        sp_rank = parallel_state._SP.rank_in_group
        device_group = parallel_state._SP.device_group
        model_forward = self.model.forward
        input_key = 'inputs_embeds' if self.is_multimodal_model else 'input_ids'

        def ulysses_forward(*args, **kwargs):
            # update inputs
            input_tensor = kwargs[input_key]
            positions = kwargs['positions']
            # Ulysses parameters
            N = input_tensor.shape[0]
            N_ulysses = N // sp_size
            N_offset = N_ulysses * sp_rank
            # narrow the input
            kwargs[input_key] = input_tensor[N_offset:N_offset + N_ulysses]
            kwargs['positions'] = positions[N_offset:N_offset + N_ulysses]
            with set_shift_parallel_mode(False):
                output = model_forward(*args, **kwargs)
            if output.size(0) == N_ulysses:
                # all-gather model_output
                model_output = torch.empty((N, self.hidden_size),
                                           dtype=output.dtype,
                                           device=output.device)
                torch.distributed.all_gather_into_tensor(model_output,
                                                         output,
                                                         group=device_group)
            else:
                # SwiftKV models will already have all-gathered the output.
                assert output.size(0) == N
                model_output = output
            return model_output

        self.model.forward = ulysses_forward

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
        """Run one scheduler step: forward pass, sampling and drafting.

        Batches of at most ``shift_parallel_threshold`` tokens run on the
        shift model (when Ulysses SP and a shift model are both enabled);
        otherwise the Ulysses model runs on input padded to a multiple of the
        SP size. After sampling (with rejection sampling when drafts were
        scheduled), the sampled tokens are cached in the input batch, the
        suffix cache is updated, and draft tokens are proposed.

        Returns:
            ``EMPTY_MODEL_RUNNER_OUTPUT`` (or the KV-connector result) when
            nothing is scheduled; the hidden states on non-final PP ranks
            unless broadcasting for ``external_launcher``; the result of
            ``self._pool`` for pooling batches; otherwise a
            ``ModelRunnerOutput``.
        """
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            if not has_kv_transfer_group():
                # Return empty ModelRunnerOutput if there's no work to do.
                return EMPTY_MODEL_RUNNER_OUTPUT
            return self.kv_connector_no_forward(scheduler_output)

        # Prepare the decoder inputs.
        (attn_metadata, attention_cuda_graphs, logits_indices,
         spec_decode_metadata,
         num_scheduled_tokens_np) = self._prepare_inputs(scheduler_output)
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

        # Small batches run on the shift model instead of the Ulysses model.
        use_shift_model = (self.use_ulysses and self.shift_model is not None
                           and num_scheduled_tokens
                           <= self.shift_parallel_threshold)
        if self.use_ulysses and not use_shift_model:
            # Add padding to the batch size to make it a multiple of SP; the
            # per-rank share (num_input_tokens // sp_size) is what gets padded
            # for CUDA graphs.
            sp_size = self.parallel_config.ulysses_sequence_parallel_size
            num_input_tokens = round_up(num_scheduled_tokens, sp_size)
            if (self.use_cuda_graph and num_input_tokens // sp_size
                    <= self.cudagraph_batch_sizes[-1]):
                num_input_tokens = self.vllm_config.pad_for_cudagraph(
                    num_input_tokens // sp_size) * sp_size
        elif (self.use_cuda_graph
              and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
            # Use piecewise CUDA graphs.
            # Add padding to the batch size.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_scheduled_tokens)
        else:
            # Eager mode.
            # Pad tokens to multiple of tensor_parallel_size when
            # enabled collective fusion for SP
            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
            pass_config = self.compilation_config.pass_config
            if pass_config.enable_sequence_parallelism and tp_size > 1:
                num_input_tokens = round_up(num_scheduled_tokens, tp_size)
            else:
                num_input_tokens = num_scheduled_tokens

        # Padding for DP
        num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
        num_input_tokens += num_pad

        # _prepare_inputs may reorder the batch, so we must gather multi
        # modal outputs after that to ensure the correct order
        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)
        else:
            mm_embeds = []

        if self.is_multimodal_model and get_pp_group().is_first_rank:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            input_ids = self.input_ids[:num_scheduled_tokens]
            if mm_embeds:
                inputs_embeds = self.model.get_input_embeddings(
                    input_ids, mm_embeds)
            else:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
            positions = self.positions[:num_input_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                num_input_tokens, intermediate_tensors, True)

        # Some attention backends only support CUDA Graphs in pure decode.
        # If attention doesn't support CUDA Graphs for this batch, but we
        # compiled with full CUDA graphs, we have to skip them entirely.
        skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs

        # Run the model.
        # Use persistent buffers for CUDA graphs.
        with set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                skip_cuda_graphs=skip_cuda_graphs,
        ):
            self.maybe_setup_kv_connector(scheduler_output)
            model = self.shift_model if use_shift_model else self.model
            with set_shift_parallel_mode(use_shift_model):
                model_output = model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                )
            self.maybe_wait_for_kv_save()
            finished_sending, finished_recving = (
                self.get_finished_kv_transfers(scheduler_output))

        if self.use_aux_hidden_state_outputs:
            hidden_states, aux_hidden_states = model_output
        else:
            hidden_states = model_output
            aux_hidden_states = None

        # Broadcast PP output for external_launcher (torchrun)
        # to make sure we are synced across pp ranks
        # TODO: Support overlapping mirco-batches
        # https://github.com/vllm-project/vllm/issues/18019
        broadcast_pp_output = (
            self.parallel_config.distributed_executor_backend
            == "external_launcher" and len(get_pp_group().ranks) > 0)
        if not get_pp_group().is_last_rank:
            # For mid-pipeline stages, return the hidden states.
            if not broadcast_pp_output:
                return hidden_states
            assert isinstance(hidden_states, IntermediateTensors)
            get_pp_group().send_tensor_dict(hidden_states.tensors,
                                            all_gather_group=get_tp_group())
            logits = None
        else:
            if self.input_batch.pooling_params:
                return self._pool(hidden_states, num_scheduled_tokens,
                                  num_scheduled_tokens_np, finished_sending,
                                  finished_recving)
            sample_hidden_states = hidden_states[logits_indices]
            logits = self.model.compute_logits(sample_hidden_states, None)
        if broadcast_pp_output:
            model_output_broadcast_data = {
                "logits": logits.contiguous(),
            } if logits is not None else {}
            model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
                model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
            assert model_output_broadcast_data is not None
            logits = model_output_broadcast_data["logits"]

        # Apply structured output bitmasks if present
        if scheduler_output.grammar_bitmask is not None:
            self.apply_grammar_bitmask(scheduler_output, logits)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        if spec_decode_metadata is None:
            sampler_output = self.sampler(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        else:
            # When indexing with a tensor (bonus_logits_indices), PyTorch
            # creates a new tensor with separate storage from the original
            # logits tensor. This means any in-place operations on bonus_logits
            # won't affect the original logits tensor.
            assert logits is not None
            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
            sampler_output = self.sampler(
                logits=bonus_logits,
                sampling_metadata=sampling_metadata,
            )
            bonus_token_ids = sampler_output.sampled_token_ids

            # Just like `bonus_logits`, `target_logits` is a new tensor with
            # separate storage from the original `logits` tensor. Therefore,
            # it is safe to update `target_logits` in place.
            target_logits = logits[spec_decode_metadata.target_logits_indices]
            output_token_ids = self.rejection_sampler(
                spec_decode_metadata,
                None,  # draft_probs
                target_logits,
                bonus_token_ids,
                sampling_metadata,
            )
            sampler_output.sampled_token_ids = output_token_ids

        num_nans_in_logits = {}
        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
            num_nans_in_logits = self._get_nans_in_logits(logits)

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        discard_sampled_tokens_req_indices = []
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len < req_state.num_tokens:
                # Ignore the sampled token for partial prefills.
                # Rewind the generator state as if the token was not sampled.
                # This relies on cuda-specific torch-internal impl details
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)
                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)

        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = (logprobs_tensors.tolists()
                          if logprobs_tensors is not None else None)

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states[:num_scheduled_tokens],
            scheduler_output,
        )

        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
            # No spec decode tokens.
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
            # Includes spec decode tokens.
            valid_sampled_token_ids = self.rejection_sampler.parse_output(
                sampled_token_ids,
                self.input_batch.vocab_size,
            )
        # Mask out the sampled tokens that should not be sampled.
        for i in discard_sampled_tokens_req_indices:
            valid_sampled_token_ids[i].clear()

        # Cache the sampled tokens in the model runner, so that the scheduler
        # doesn't need to send them back.
        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
        # the sampled tokens back, because there's no direct communication
        # between the first-stage worker and the last-stage worker.
        # After this loop, token_ids_cpu[:num_tokens_no_spec] of each request
        # already ends with its sampled tokens; the draft proposers rely on
        # this.
        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
            if not sampled_ids:
                continue

            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
            end_idx = start_idx + len(sampled_ids)
            assert end_idx <= self.max_model_len, (
                "Sampled token IDs exceed the max model length. "
                f"Total number of tokens: {end_idx} > max_model_len: "
                f"{self.max_model_len}")

            self.input_batch.token_ids_cpu[req_idx,
                                           start_idx:end_idx] = sampled_ids
            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
            self.input_batch.num_tokens[req_idx] = end_idx
            req_id = self.input_batch.req_ids[req_idx]
            req_state = self.requests[req_id]
            req_state.output_token_ids.extend(sampled_ids)

        if self._suffix_cache is not None:
            self._update_suffix_cache(valid_sampled_token_ids)

        if not self.speculative_config:
            # Speculative decoding is not enabled.
            spec_token_ids = None
        else:
            spec_token_ids = self.propose_draft_token_ids(
                scheduler_output,
                valid_sampled_token_ids,
                sampler_output.sampled_token_ids,
                sampling_metadata,
                hidden_states,
                sample_hidden_states,
                aux_hidden_states,
                spec_decode_metadata,
                attn_metadata,
            )

        # Clear KVConnector state after all KVs are generated.
        if has_kv_transfer_group():
            get_kv_transfer_group().clear_connector_metadata()

        self.eplb_step()

        return ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
            pooler_output=[],
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            num_nans_in_logits=num_nans_in_logits,
        )

    def propose_draft_token_ids(
        self,
        scheduler_output: "SchedulerOutput",
        sampled_token_ids: list[list[int]],
        original_sampled_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        hidden_states: torch.Tensor,
        sample_hidden_states: torch.Tensor,
        aux_hidden_states: Optional[torch.Tensor],
        spec_decode_metadata: Optional[SpecDecodeMetadata],
        attn_metadata: dict[str, Any],
    ) -> Optional[list[list[int]]]:
        """Propose draft tokens for every request in the batch.

        Suffix decoding (when its cache is enabled) is tried first. A request
        whose suffix draft score reaches the threshold uses that draft and is
        excluded from the other drafter ("arctic"/"mlp_speculator" or the
        original vLLM method). The suffix draft is preferred whenever it is
        non-empty.

        Args:
            sampled_token_ids: valid sampled tokens per request (empty for
                requests whose sample was discarded).
            original_sampled_token_ids: the raw ``sampled_token_ids`` tensor
                from the sampler output, forwarded to the Arctic drafter.

        Returns:
            One draft list per request, or ``None`` if no method produced
            drafts.
        """
        disable_spec_decode = (self.speculative_config
                               and self.speculative_config.disable_by_batch_size
                               and len(self.input_batch.req_ids)
                               > self.speculative_config.disable_by_batch_size)
        if disable_spec_decode:
            # Speculative decoding is disabled for batches this large.
            return [[] for _ in sampled_token_ids]

        suffix_spec_token_ids = None
        # Shallow copy: entries are replaced (never mutated) below, so the
        # caller's list is left untouched.
        new_sampled_token_ids = sampled_token_ids.copy()
        if self._suffix_cache is not None:
            results = self.propose_suffix_draft_token_ids(
                new_sampled_token_ids)
            # The score is an estimate of the acceptance length. Thus, the
            # heuristic is to use the suffix decoded tokens if the score is
            # at least the # of tokens we would speculate otherwise.
            min_score = (0 if self.speculative_config.method == "suffix" else
                         self.speculative_config.num_speculative_tokens)
            suffix_spec_token_ids = []
            for i, result in enumerate(results):
                if result.score >= min_score:
                    # Use suffix decoded tokens, disable other speculation
                    # methods for this request.
                    new_sampled_token_ids[i] = []
                    suffix_spec_token_ids.append(result.token_ids)
                else:
                    suffix_spec_token_ids.append([])

        method = self.speculative_config.method
        spec_token_ids = None
        if method == "suffix":
            # Suffix decoding is the only drafter for this method.
            pass
        elif method in ("arctic", "mlp_speculator"):
            assert isinstance(self.drafter, ArcticProposer)
            previous_hidden_states = self.drafter.prepare_hidden_states(
                sample_hidden_states=sample_hidden_states,
                sampled_token_ids=original_sampled_token_ids,
                spec_decode_metadata=spec_decode_metadata,
            )
            spec_token_ids = self.propose_arctic_draft_token_ids(
                scheduler_output,
                new_sampled_token_ids,
                previous_hidden_states=previous_hidden_states)
        else:
            spec_token_ids = self._orig_propose_draft_token_ids(
                scheduler_output,
                new_sampled_token_ids,
                sampling_metadata,
                hidden_states,
                sample_hidden_states,
                aux_hidden_states,
                spec_decode_metadata,
                attn_metadata,
            )

        if spec_token_ids is None:
            return suffix_spec_token_ids
        if suffix_spec_token_ids is not None:
            # Prefer the suffix draft for a request whenever it is non-empty.
            spec_token_ids = [
                suffix_ids or other_ids for suffix_ids, other_ids in zip(
                    suffix_spec_token_ids, spec_token_ids)
            ]
        return spec_token_ids

    def propose_arctic_draft_token_ids(
        self,
        scheduler_output: "SchedulerOutput",
        sampled_token_ids: list[list[int]],
        previous_hidden_states: Optional[torch.Tensor] = None,
    ) -> list[list[int]]:
        """Propose draft tokens with the Arctic / MLP-speculator drafter.

        The drafter is fed the last cached token of every request. By the
        time this runs, ``execute_model`` has already appended the sampled
        tokens to ``token_ids_cpu`` and advanced ``num_tokens_no_spec``.
        Requests with no sampled tokens still occupy a row in the drafter
        batch, but their drafts are cleared before returning.

        Returns:
            One draft list per request. All lists are empty if suffix
            decoding is enabled and any request has no sampled tokens, or if
            any request is too close to ``max_model_len`` to speculate.
        """
        # ``scheduler_output`` is kept for interface compatibility.
        del scheduler_output
        last_tokens = []
        max_spec_tokens = self.speculative_config.num_speculative_tokens
        for i, sampled_ids in enumerate(sampled_token_ids):
            if not sampled_ids and self.speculative_config.enable_suffix_decoding:
                return [[] for _ in sampled_token_ids]
            num_tokens = int(self.input_batch.num_tokens_no_spec[i])
            # Leave room so that the drafted tokens stay below max_model_len.
            max_spec_tokens = min(max_spec_tokens,
                                  self.max_model_len - num_tokens - 1)
            if max_spec_tokens <= 0:
                # Speculation is disabled for the whole batch below.
                break
            # For requests with sampled tokens this is the last sampled
            # token; otherwise it is the last cached token (a placeholder
            # whose draft is discarded below).
            last_tokens.append(self.input_batch.token_ids_cpu[i,
                                                              num_tokens - 1])

        if max_spec_tokens <= 0:
            return [[] for _ in sampled_token_ids]

        drafter_output = self.drafter.propose(
            last_tokens,
            previous_hidden_states=previous_hidden_states,
            num_predict_tokens=max_spec_tokens,
        )

        draft_token_ids = drafter_output.tolist()
        for i, sampled_ids in enumerate(sampled_token_ids):
            if not sampled_ids:
                draft_token_ids[i] = []

        return draft_token_ids

    def _update_suffix_cache(self, sampled_token_ids: list[list[int]]) -> None:
        """Feed newly sampled tokens into the suffix decoding cache.

        A request with sampled tokens that is not yet active in the cache is
        started with its prompt tokens. Any previously cached response for it
        is evicted first. Active requests that are no longer in the batch are
        stopped.
        """
        seen_req_ids = set()
        for i, sampled_ids in enumerate(sampled_token_ids):
            req_id = self.input_batch.req_ids[i]
            seen_req_ids.add(req_id)

            if not sampled_ids:
                continue

            index = self.input_batch.req_id_to_index[req_id]
            if req_id not in self._suffix_cache.active_requests:
                if req_id in self._suffix_cache.cached_requests:
                    # Reset the suffix cache for this request.
                    self._suffix_cache.evict_cached_response(req_id)
                num_prompt_tokens = self.input_batch.num_prompt_tokens[index]
                prompt_token_ids = (
                    self.input_batch.token_ids_cpu[index, :num_prompt_tokens])
                self._suffix_cache.start_request(req_id, prompt_token_ids)

            self._suffix_cache.add_active_response(req_id, sampled_ids)

        # Stop requests that are not seen
        for req_id in list(self._suffix_cache.active_requests):
            if req_id not in seen_req_ids:
                self._suffix_cache.stop_request(req_id)

    def propose_suffix_draft_token_ids(
        self,
        sampled_token_ids: list[list[int]],
        spec_token_ids: Optional[list[list[int]]] = None,
    ) -> list["SuffixDecodingDraft"]:
        """Query the suffix cache for a draft for every request in the batch.

        The match pattern for each request is the tail of its cached tokens.
        ``execute_model`` has already appended the sampled tokens and advanced
        ``num_tokens_no_spec``, so they are not written again here. Any
        ``spec_token_ids`` are then appended to the pattern.

        Args:
            sampled_token_ids: tokens sampled this step per request; requests
                with an empty list get an empty ``SuffixDecodingDraft``.
            spec_token_ids: optional tokens already speculated per request;
                they extend the pattern and reduce the speculation budget.

        Returns:
            One ``SuffixDecodingDraft`` per request, in batch order.
        """
        config = self.speculative_config
        results = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            spec_ids = spec_token_ids[i] if spec_token_ids is not None else []
            if not sampled_ids:
                # Skip speculative decoding.
                results.append(SuffixDecodingDraft())
                continue

            req_id = self.input_batch.req_ids[i]
            num_tokens = int(self.input_batch.num_tokens_no_spec[i])
            if num_tokens >= self.max_model_len:
                # No room left for speculative tokens.
                results.append(SuffixDecodingDraft())
                continue

            size = min(num_tokens, config.suffix_cache_max_depth)
            pattern = self.input_batch.token_ids_cpu[i, num_tokens -
                                                     size:num_tokens]
            pattern = pattern.tolist() + spec_ids
            if len(pattern) > config.suffix_cache_max_depth:
                pattern = pattern[-config.suffix_cache_max_depth:]
            max_spec_tokens = min(MAX_SPEC_LEN - len(spec_ids),
                                  config.suffix_cache_max_depth,
                                  self.max_model_len - num_tokens - 1)
            # max_spec_offset is modified to mimic the behavior of the original
            # max_spec_factor and max_spec_offset as if the speculative tokens
            # were generated by suffix decoding. For example, if:
            #   - max_spec_factor = 2
            #   - max_spec_offset = -1
            #   - we've already speculated 3 tokens
            #   - and the suffix match length is 6
            # Then:
            #   - The match length before the already-speculated tokens is 3
            #   - The original config allow up to 5 speculated tokens total
            #   - Already speculated 3 tokens, so should allow 2 more tokens
            # So the new config should map match length 6 to 2 max spec tokens.
            max_spec_factor = config.suffix_max_spec_factor
            max_spec_offset = (config.suffix_max_spec_offset - len(spec_ids) *
                               (max_spec_factor + 1))
            result = self._suffix_cache.speculate(
                req_id,
                pattern,
                max_spec_tokens=max_spec_tokens,
                max_spec_factor=max_spec_factor,
                max_spec_offset=max_spec_offset,
                min_token_prob=config.suffix_min_token_prob)

            results.append(result)

        return results

    def load_model(self) -> None:
        """Load the model and, when enabled, the shift-parallel model.

        With Ulysses SP, the model's ``forward`` is wrapped by
        ``monkeypatch_forward``. With shift parallelism enabled, a second
        model is loaded in shift-parallel mode from a copy of the config
        whose tensor-parallel size is multiplied by the Ulysses SP size
        (and whose SP size is reset to 1).
        """
        load_shift_model = (
            self.vllm_config.parallel_config.enable_shift_parallel)
        if load_shift_model:
            # Make a deep copy of the config before loading the model.
            shift_config = copy.deepcopy(self.vllm_config)
        self._orig_load_model()
        if self.parallel_config.ulysses_sequence_parallel_size > 1:
            self.monkeypatch_forward()
        if load_shift_model:
            shift_config.parallel_config.tensor_parallel_size *= (
                shift_config.parallel_config.ulysses_sequence_parallel_size)
            shift_config.parallel_config.ulysses_sequence_parallel_size = 1
            with set_shift_parallel_mode(True):
                self.shift_model = get_model(vllm_config=shift_config)
            self.shift_parallel_threshold = (
                shift_config.parallel_config.shift_parallel_threshold)
            if "SwiftKV" in self.model.__class__.__name__:
                # HACK: Replace the decode-runner since it always runs in full
                # TP, but the original model is captured using SP * BATCH_SIZE,
                # which does not cover all its cuda graph sizes. The shift-mode
                # model should have all its cuda graphs captured correctly.
                self.model.model.decode_runner = (
                    self.shift_model.model.decode_runner)
        else:
            self.shift_model = None
            self.shift_parallel_threshold = 0

    def capture_model(self) -> None:
        """Capture CUDA graphs for the original and the shift model.

        The original model is captured for batch sizes whose total token
        count (``shape * sp_size``) is above ``shift_parallel_threshold`` and
        at most ``max_num_tokens``. The shift model, if loaded, is captured in
        shift-parallel mode for sizes up to the threshold, or for every size
        when it is a SwiftKV model. ``self.model`` is restored even if
        capturing fails.
        """
        if not self.use_cuda_graph:
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "set -O %s and ensure `use_cudagraph` was not manually set to "
                "False", CompilationLevel.PIECEWISE)
            return

        compilation_counter.num_gpu_runner_capture_triggers += 1

        start_time = time.perf_counter()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with parallel_state.graph_capture(device=self.device):
            sp_size = self.parallel_config.ulysses_sequence_parallel_size
            full_cg = self.full_cuda_graph
            num_warmups = (
                self.vllm_config.compilation_config.cudagraph_num_of_warmups)

            # Capture original model shapes.
            compilation_cases = [
                shape for shape in reversed(self.cudagraph_batch_sizes)
                if self.shift_parallel_threshold < shape *
                sp_size <= self.max_num_tokens
            ]
            # Only rank 0 should print progress bar during capture
            if is_global_first_rank():
                logger.info("original model shapes %s", compilation_cases)
                compilation_cases = tqdm(
                    compilation_cases,
                    desc="Capturing CUDA graph shapes of original model")
            for num_tokens in compilation_cases:
                # We skip EPLB here since we don't want to record dummy metrics
                for _ in range(num_warmups):
                    self._dummy_run(num_tokens * sp_size,
                                    capture_attn_cudagraph=full_cg,
                                    skip_eplb=True)
                self._dummy_run(num_tokens * sp_size,
                                capture_attn_cudagraph=full_cg,
                                skip_eplb=True)

            # Capture shift model shapes.
            if self.shift_model is not None:
                # Note: We want to capture all shapes for the SwiftKV shift
                # model. This is necessary since SwiftKV always uses full TP
                # for the decode runner. For all other models, we only capture
                # necessary shapes for the SP_TP mode, yielding less setup
                # time.
                capture_all = "SwiftKV" in self.shift_model.__class__.__name__
                compilation_cases = [
                    shape for shape in reversed(self.cudagraph_batch_sizes)
                    if shape <= self.shift_parallel_threshold or capture_all
                ]
                if is_global_first_rank():
                    logger.info("shift model shapes %s", compilation_cases)
                    compilation_cases = tqdm(
                        compilation_cases,
                        desc="Capturing CUDA graph shapes of shift model")
                orig_model, self.model = self.model, self.shift_model
                try:
                    with set_shift_parallel_mode(True):
                        for num_tokens in compilation_cases:
                            for _ in range(num_warmups):
                                self._dummy_run(num_tokens,
                                                capture_attn_cudagraph=full_cg,
                                                skip_eplb=True)
                            self._dummy_run(num_tokens,
                                            capture_attn_cudagraph=full_cg,
                                            skip_eplb=True)
                finally:
                    self.model = orig_model

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """Initialize KV caches and share them with the shift model.

        After the original initialization, every ``Attention`` module of the
        shift model (if loaded) is pointed at the ``kv_cache`` registered for
        the same layer name in the static forward context.
        """
        self._orig_initialize_kv_cache(kv_cache_config)
        if self.shift_model is not None:
            # Bind the KV caches to the shift parallel model.
            forward_context = (
                self.vllm_config.compilation_config.static_forward_context)
            for mod in self.shift_model.modules():
                if isinstance(mod, Attention):
                    mod.kv_cache = forward_context[mod.layer_name].kv_cache