arctic_inference/vllm/spec_dec/arctic_proposer.py (116 lines of code) (raw):

# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

from vllm.config import VllmConfig
from vllm.model_executor.model_loader import get_model
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_model_runner import logger

import numpy as np
import torch

from arctic_inference.vllm.spec_dec.arctic_speculator import (
    ArcticMLPSpeculator,
    ArcticLSTMSpeculator,
)
from arctic_inference.envs import ARCTIC_INFERENCE_SKIP_SPEC_MODEL_CHECK


class ArcticProposer:
    """Proposes draft tokens with an Arctic speculator loaded through vLLM.

    Typical usage is ``load_model()`` once, then per step
    ``prepare_hidden_states()`` followed by ``propose()``.
    """

    # Draft-model architecture names accepted by ``load_model``.
    _SUPPORTED_SPEC_ARCHS = (
        "ArcticMLPSpeculatorPreTrainedModel",
        "ArcticLSTMSpeculatorPreTrainedModel",
        "MLPVariantSpeculatorPreTrainedModel",
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
    ):
        """Store the configuration; the draft model is loaded lazily.

        Args:
            vllm_config: Full vLLM configuration. Its ``speculative_config``
                describes the draft model.
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config

        # All three are populated by load_model().
        self.model = None
        self.device = None
        self.input_hidden_dim = None

    def load_model(
        self,
        model: Union[ArcticMLPSpeculator, ArcticLSTMSpeculator],
    ):
        """Validate the draft-model config and load the speculator.

        Args:
            model: Module whose first parameter determines ``self.device``.

        Raises:
            TypeError: The draft config's ``architectures`` is not a list.
            ValueError: ``architectures`` does not hold exactly one entry, or
                that entry is not a supported Arctic speculator.
            AssertionError: The draft model was not trained for the base
                model architecture (check skipped when
                ``ARCTIC_INFERENCE_SKIP_SPEC_MODEL_CHECK`` is set).
        """
        draft_model_config = self.speculative_config.draft_model_config
        draft_hf_config = draft_model_config.hf_config

        self._check_draft_architecture(draft_hf_config)
        if not ARCTIC_INFERENCE_SKIP_SPEC_MODEL_CHECK:
            self._check_base_architecture(draft_hf_config)

        draft_quant_config = VllmConfig._get_quantization_config(
            self.vllm_config.model_config,
            self.vllm_config.load_config,
        )

        draft_parallel_config = self.speculative_config.draft_parallel_config
        draft_parallel_config.worker_cls = \
            self.vllm_config.parallel_config.sd_worker_cls

        # We cannot use deepcopy here because Ulysses introduces
        # torch._C._distributed_c10d.ProcessGroup objects that are not
        # designed to be pickled.
        draft_worker_config = VllmConfig(
            model_config=draft_model_config,
            quant_config=draft_quant_config,
            parallel_config=draft_parallel_config,
            load_config=self.vllm_config.load_config,
            device_config=self.vllm_config.device_config,
        )

        self.model = get_model(vllm_config=draft_worker_config)

        # The device comes from the `model` argument, not from the draft
        # model that was just loaded.
        # NOTE(review): confirm this is intentional.
        self.device = next(model.parameters()).device

        if isinstance(self.model, ArcticLSTMSpeculator):
            self.input_hidden_dim = self.model.input_hidden_dim
        else:
            self.input_hidden_dim = self.model.emb_dim

    def _check_draft_architecture(self, draft_hf_config) -> None:
        """Ensure the draft config names exactly one supported speculator."""
        spec_model_archs = draft_hf_config.architectures

        if not isinstance(spec_model_archs, list):
            message = (
                f"Draft model architectures {spec_model_archs} is not a list. "
            )
            logger.error(message)
            raise TypeError(message)

        if len(spec_model_archs) != 1:
            message = (
                f"Draft model architectures {spec_model_archs} does not have exactly one architecture. "
            )
            logger.error(message)
            raise ValueError(message)

        if spec_model_archs[0] not in self._SUPPORTED_SPEC_ARCHS:
            message = (
                f"Draft model architecture {spec_model_archs} is not supported by Arctic Speculator. "
            )
            logger.error(message)
            raise ValueError(message)

    def _check_base_architecture(self, draft_hf_config) -> None:
        """Ensure the draft model was trained for the configured base model.

        Raises AssertionError explicitly rather than via ``assert False`` so
        that the check is still enforced when Python runs with ``-O``.
        """
        base_model_arch = self.vllm_config.model_config.architectures[0]

        if not hasattr(draft_hf_config, "base_model_archs"):
            message = (
                "Draft model config does not have base_model_archs attribute. "
                "Set ARCTIC_INFERENCE_SKIP_SPEC_MODEL_CHECK=1 to skip this assertion."
            )
            logger.error(message)
            raise AssertionError(message)

        base_model_archs_in_spec_config = draft_hf_config.base_model_archs
        if base_model_arch not in base_model_archs_in_spec_config:
            message = (
                f"Draft model trained with base model architectures {base_model_archs_in_spec_config} "
                f"does not match the base model architecture {base_model_arch} in the vLLM config. "
                "Set ARCTIC_INFERENCE_SKIP_SPEC_MODEL_CHECK=1 to skip this assertion."
            )
            logger.error(message)
            raise AssertionError(message)

    @staticmethod
    def _as_token_tensor(sampled_token_ids, device) -> torch.Tensor:
        """Convert a NumPy array or (possibly ragged) list of lists to a tensor.

        Ragged rows are right-padded with -1, the same marker
        ``prepare_hidden_states`` treats as "no token".
        """
        if isinstance(sampled_token_ids, np.ndarray):
            return torch.as_tensor(sampled_token_ids, device=device)

        max_len = max((len(row) for row in sampled_token_ids), default=0)
        padded = [
            list(row) + [-1] * (max_len - len(row))
            for row in sampled_token_ids
        ]
        # The reshape keeps the result 2-D even when there are no rows.
        return torch.tensor(padded, dtype=torch.long,
                            device=device).reshape(len(padded), max_len)

    def prepare_hidden_states(
        self,
        sample_hidden_states: torch.Tensor,
        sampled_token_ids: Union[torch.Tensor, np.ndarray, list[list[int]]],
        spec_decode_metadata: SpecDecodeMetadata,
    ) -> torch.Tensor:
        """Select, per request, the hidden state of its last valid token.

        Args:
            sample_hidden_states: Hidden states at the sampled positions. Its
                last dimension must equal ``self.input_hidden_dim``.
            sampled_token_ids: 2-D token ids, one row per request, with -1
                marking unused slots. Arrays and lists are converted to a
                tensor first; tensors are used as given.
            spec_decode_metadata: Required when more than one token per
                request may have been sampled; supplies ``num_draft_tokens``.

        Returns:
            ``sample_hidden_states`` unchanged when at most one token per
            request was sampled, otherwise the rows gathered at each
            request's last valid position.
        """
        if sample_hidden_states is not None:
            assert sample_hidden_states.shape[-1] == self.input_hidden_dim, (
                f"hidden_states shape mismatch: "
                f"{sample_hidden_states.shape[-1]} != {self.input_hidden_dim}. "
                "Please make sure spec model is trained using the same base model."
            )

        if not isinstance(sampled_token_ids, torch.Tensor):
            target_device = (None if sample_hidden_states is None else
                             sample_hidden_states.device)
            sampled_token_ids = self._as_token_tensor(sampled_token_ids,
                                                      target_device)
            if sampled_token_ids.shape[-1] == 0:
                # Nothing was sampled, so there is nothing to select.
                return sample_hidden_states

        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
            return sample_hidden_states

        assert spec_decode_metadata is not None

        # Number of non-padding tokens in each row.
        valid_mask = sampled_token_ids != -1
        gen_lens = valid_mask.sum(dim=1)

        # Each request contributes num_draft_tokens + 1 sampled positions.
        num_sampled_tokens = np.array(spec_decode_metadata.num_draft_tokens)
        num_sampled_tokens = torch.tensor(num_sampled_tokens,
                                          device=gen_lens.device) + 1

        # cumsum - self is the exclusive prefix sum, i.e. the offset of each
        # request's first position; adding gen_lens - 1 lands on its last
        # valid token. Assumes sample_hidden_states rows are laid out request
        # by request in that order - TODO confirm against the caller.
        row_offsets = torch.cumsum(num_sampled_tokens, 0) - num_sampled_tokens
        hidden_states_idx = (gen_lens - 1) + row_offsets

        return sample_hidden_states[hidden_states_idx]

    def propose(
        self,
        context_token_ids: np.ndarray,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
    ) -> Optional[np.ndarray]:
        """Generate draft tokens with the loaded speculator.

        Args:
            context_token_ids: Token ids fed to the speculator as
                ``input_ids``; moved to ``self.device``.
            previous_hidden_states: Hidden states passed through to
                ``generate_proposals``.
            num_predict_tokens: Number of tokens to propose; must be > 0.

        Returns:
            The output of ``generate_proposals``, copied to the CPU as a
            NumPy array.
        """
        assert num_predict_tokens > 0, \
            f"num_predict_tokens must be greater than 0, got {num_predict_tokens}."

        input_ids = torch.tensor(context_token_ids, device=self.device)

        next_tokens = self.model.generate_proposals(
            input_ids=input_ids,
            previous_hidden_states=previous_hidden_states,
            num_predict_tokens=num_predict_tokens,
        )

        return next_tokens.cpu().numpy()