arctic_inference/vllm/spec_dec/arctic_speculator.py (719 lines of code) (raw):

# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import logging
import math
from contextlib import contextmanager
from typing import Callable, Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from arctic_inference.vllm.spec_dec.logits_processor_opt import LogitsProcessorOpt
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from arctic_inference.vllm.spec_dec.fp8 import (Fp8ConfigWithEmbedding,
                                                OriginalFp8LinearMethod)
from arctic_inference.vllm.spec_dec.vocab_parallel_embedding import (
    SpeculatorTPInit,
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

logger = logging.getLogger(__name__)

SQRT2 = 2**0.5

# Batches with at most this many rows use the FP8-quantized LM head (when one
# was built); larger batches use the unquantized head.
_QUANTIZED_HEAD_MAX_BATCH_SIZE = 32


def padding_size(size: int) -> int:
    """Round ``size`` up to the nearest value of the form ``2**k`` or ``3 * 2**k``.

    The rounding step is a quarter of the next power of two >= ``size``, e.g.
    5 -> 6, 7 -> 8, 13 -> 16, 17 -> 24, 100 -> 128. Sizes up to 4 (and 0) are
    returned unchanged. Used to bucket batch sizes for CUDA graph capture.
    """
    mult = (1 << (size - 1).bit_length()) // 4
    if mult < 1:
        return size
    return (size + mult - 1) // mult * mult


@contextmanager
def graph_capture(device: Union[int, torch.device]):
    """Enter vLLM's tensor-parallel graph-capture context on a new CUDA stream.

    Yields the ``GraphCaptureContext``; callers capture on its ``stream``.
    """
    from vllm.distributed.parallel_state import GraphCaptureContext
    import vllm.distributed.parallel_state as parallel_state

    context = GraphCaptureContext(torch.cuda.Stream(device=device))
    with parallel_state._TP.graph_capture(context):
        yield context


def _capture_cuda_graph(run: Callable[[], object]) -> torch.cuda.CUDAGraph:
    """Record the GPU work issued by ``run`` into a new CUDA graph.

    Capturing only records kernels; it does not execute them. The caller must
    ``replay()`` the returned graph to actually produce outputs.
    """
    device = torch.cuda.current_device()
    with graph_capture(device=device) as capture_context:
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=capture_context.stream):
            run()
    return graph


class MLPSpeculatorLayerNorm(nn.Module):
    """
    RMS normalization over the last tensor axis: ``x / sqrt(mean(x**2) + eps)``
    (no mean subtraction), optionally followed by a learned scale and shift.

    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
        fits in the range of your encoding scheme
        (i.e. fp16 requires eps >= 6e-8).
    elementwise_scale_and_shift : bool
        Include a learned scaling and shift term after normalization.
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
        elementwise_scale_and_shift=True,
    ):
        super().__init__()
        self.elementwise_scale_and_shift = elementwise_scale_and_shift
        if self.elementwise_scale_and_shift:
            self.weight = nn.Parameter(torch.empty(normalized_shape))
            self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
        # Computed in the input dtype (no upcast), matching the original
        # speculator implementation the weights were trained with.
        xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        if self.elementwise_scale_and_shift:
            x = self.weight * x
            x = x + self.bias
        return x


def _generate_cg_key(padding_size: int, head_index: int):
    """Pack a padded batch size and a small discriminator (< 2**16) into the
    single int used as the key of a speculator's CUDA graph cache."""
    return (padding_size << 16) + head_index


def _parse_dims(value: Union[str, int, Iterable[int]]) -> List[int]:
    """Parse a '.'-separated dimension spec (e.g. ``"4096.4096"``) into ints.

    Plain ints and already-parsed sequences are accepted as well, so parsing
    the same (possibly shared) config object twice is harmless.
    """
    if isinstance(value, str):
        return [int(part) for part in value.split(".")]
    if isinstance(value, int):
        return [value]
    return [int(part) for part in value]


def _build_lm_heads(
    vocab_size: int,
    hidden_dim: int,
    num_heads: int,
    tie_weights: bool,
    quant_config,
    quantize_lm_head: bool,
) -> Tuple[nn.ModuleList, Optional[nn.ModuleList]]:
    """Build the per-step LM heads and, for tied models, their FP8 twins.

    Returns ``(heads, qheads)``. With ``tie_weights`` every step shares one
    head module, and (if ``quantize_lm_head``) one FP8-quantized copy.
    Untied models get one head per step and no quantized copy (``None``).
    """
    if not tie_weights:
        heads = nn.ModuleList([
            ParallelLMHead(
                vocab_size,
                hidden_dim,
                bias=False,
                quant_config=quant_config,
            ) for _ in range(num_heads)
        ])
        return heads, None

    head = ParallelLMHead(
        vocab_size,
        hidden_dim,
        bias=False,
        quant_config=quant_config,
        skip_quantization=True,
    )
    qheads = None
    if quantize_lm_head:
        qhead = ParallelLMHead(
            vocab_size,
            hidden_dim,
            bias=False,
            quant_config=quant_config,
            skip_quantization=False,
        )
        qhead.quant_method = OriginalFp8LinearMethod(quant_config=quant_config)
        qheads = nn.ModuleList([qhead] * num_heads)
    return nn.ModuleList([head] * num_heads), qheads


def _greedy_next_tokens(tp: SpeculatorTPInit, logits: torch.Tensor,
                        batch_size: int) -> torch.Tensor:
    """Return the argmax token id of each row of ``logits``, shaped
    ``(batch_size, -1)``, reducing across tensor-parallel ranks if needed.

    ``tp`` supplies ``tp_size`` / ``tp_rank`` / ``TP_GROUP``. With TP > 1 each
    rank takes its local top-1, offsets the index by ``tp_rank * local_vocab``
    (assumes equally-sized vocab shards -- TODO confirm), and bit-packs the
    float64 score as int64 next to the id so one all_gather moves both.
    """
    if tp.tp_size == 1:
        return torch.argmax(logits, dim=-1).reshape(batch_size, -1)

    vals, indices = torch.topk(logits, 1, dim=-1)
    indices = indices + tp.tp_rank * logits.shape[-1]
    packed_data = torch.cat(
        [vals.to(torch.float64).view(torch.int64), indices], dim=0)
    packed_data = tp.TP_GROUP.all_gather(packed_data)
    vals, indices = packed_data.split(batch_size, dim=0)
    vals = vals.view(torch.float64)
    argidx = torch.argmax(vals, -1).reshape(batch_size, -1)
    return torch.gather(indices, -1, argidx)


class ArcticMLPSpeculator(nn.Module, SpeculatorTPInit):
    """
    An implementation of the speculative models introduced in
    "Accelerating Production LLMs with Combined Token/Embedding
    Speculators"
    https://arxiv.org/pdf/2404.19124

    Trained speculators of this type are available on HF hub at:
    https://huggingface.co/ibm-fms and https://huggingface.co/ibm-granite
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        SpeculatorTPInit.__init__(self)

        config = vllm_config.model_config.hf_config
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        # inner_dim == 0 means "use emb_dim".
        self.inner_dim = (config.inner_dim
                          if config.inner_dim != 0 else config.emb_dim)
        self.max_speculative_tokens = config.num_lookahead_tokens
        self.tie_weights = config.tie_weights
        self.scale_input = config.scale_input

        self.quantize_lm_head = True
        quant_config = (Fp8ConfigWithEmbedding()
                        if self.quantize_lm_head else None)

        num_steps = self.max_speculative_tokens
        self.qhead = None
        if self.tie_weights:
            assert (
                self.n_predict > 1
            ), "You cannot tie weights between stages when only 1 exists"
            embedding = VocabParallelEmbedding(
                config.vocab_size,
                self.inner_dim,
                org_num_embeddings=config.vocab_size)
            self.emb = nn.ModuleList([embedding] * num_steps)

            # the initial projection from the base model may
            # have a different size, so that stays separate.
            proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
            proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
            self.proj = nn.ModuleList([proj_first] +
                                      [proj_tied] * (num_steps - 1))

            self.head, self.qhead = _build_lm_heads(
                self.vocab_size,
                self.inner_dim,
                num_steps,
                tie_weights=True,
                quant_config=quant_config,
                quantize_lm_head=self.quantize_lm_head,
            )

            ln = MLPSpeculatorLayerNorm(self.inner_dim,
                                        elementwise_scale_and_shift=True)
            self.ln = nn.ModuleList([ln] * num_steps)
        else:
            self.emb = nn.ModuleList([
                VocabParallelEmbedding(
                    config.vocab_size,
                    self.inner_dim,
                    org_num_embeddings=config.vocab_size,
                ) for _ in range(num_steps)
            ])
            self.proj = nn.ModuleList([
                nn.Linear(
                    (self.emb_dim if i == 0 else self.inner_dim),
                    self.inner_dim,
                    bias=False,
                ) for i in range(num_steps)
            ])
            self.head, self.qhead = _build_lm_heads(
                self.vocab_size,
                self.inner_dim,
                num_steps,
                tie_weights=False,
                quant_config=quant_config,
                quantize_lm_head=self.quantize_lm_head,
            )
            self.ln = nn.ModuleList([
                MLPSpeculatorLayerNorm(self.inner_dim,
                                       elementwise_scale_and_shift=True)
                for _ in range(num_steps)
            ])

        if self.scale_input:
            self.ln0 = MLPSpeculatorLayerNorm(
                self.emb_dim, elementwise_scale_and_shift=False)

        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim / 2))

        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessorOpt(
            vocab_size=config.vocab_size,
            org_vocab_size=config.vocab_size,
            scale=1.0,
            skip_last_gather=True,
        )
        self.sampler = get_sampler()

        self.cuda_graph_max_batch_size = 0
        self.cuda_graph_mode = False
        if not vllm_config.model_config.enforce_eager:
            self.cuda_graph_mode = True
            self.cuda_graphs = {}
            self.cuda_graph_max_batch_size = padding_size(
                vllm_config.scheduler_config.max_num_seqs)
            max_bs = self.cuda_graph_max_batch_size
            # Zero-initialized: padded rows past the live batch still go
            # through the embedding lookup, so they must hold valid token ids.
            self.static_cuda_buffers = {
                "last_tokens":
                torch.zeros(max_bs, 1, dtype=torch.long),
                "previous_hidden_states":
                torch.zeros(max_bs, 1, self.emb_dim),
                # One output buffer per step that generate_proposals accepts.
                "next_tokens": [
                    torch.zeros(max_bs, 1, dtype=torch.long) for _ in range(
                        max(self.n_predict, self.max_speculative_tokens))
                ],
            }

    def _prepare_cuda_graph_ios(self, size, last_tokens,
                                previous_hidden_states):
        """Copy the live batch into the static input buffers.

        Returns ``(padded_size, static_last_tokens, static_hidden_states)``,
        where the static tensors are views of the first ``padded_size`` rows.
        """
        self.static_cuda_buffers["last_tokens"][:size] = last_tokens
        if previous_hidden_states is not None:
            self.static_cuda_buffers[
                "previous_hidden_states"][:size] = previous_hidden_states

        padded_size = padding_size(size)

        static_last_tokens = self.static_cuda_buffers[
            "last_tokens"][:padded_size]
        static_hidden_states = self.static_cuda_buffers[
            "previous_hidden_states"][:padded_size]
        return (padded_size, static_last_tokens, static_hidden_states)

    def generate_states(
        self,
        last_tokens: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        head_index: int,
    ) -> torch.Tensor:
        """Compute the speculator state for step ``head_index``.

        Embeds ``last_tokens``, projects ``previous_hidden_states``, mixes the
        two, then applies that step's layer norm and GELU.
        """
        if head_index == 0 and self.scale_input:
            previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
        # Project and predict
        z = self.emb[head_index](last_tokens)  # b k d
        states = self.proj[head_index](previous_hidden_states)

        # Weighted add of state_weight*state and emb_weight*z
        # Let subsequent LN take care of denominator
        # state_weight is close to 1, so shouldn't be any precision issues
        states.add_(z, alpha=self.emb_weight / self.state_weight)

        states = self.activation(self.ln[head_index](states))  # b k d
        return states

    def generate_token_ids(
        self,
        batch_size: int,
        num_predict_tokens: int,
        last_tokens: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        next_tokens_tensors: List[Optional[torch.Tensor]],
    ) -> List[Optional[torch.Tensor]]:
        """Greedily predict ``num_predict_tokens`` tokens, one per step.

        Step ``i``'s token ids are written into ``next_tokens_tensors[i]``:
        copied in place when a (static) tensor is already there, otherwise
        stored as a new tensor. The same list is returned.
        """
        use_qhead = (self.qhead is not None
                     and batch_size <= _QUANTIZED_HEAD_MAX_BATCH_SIZE)
        heads = self.qhead if use_qhead else self.head
        for head_index in range(num_predict_tokens):
            states = self.generate_states(last_tokens, previous_hidden_states,
                                          head_index)
            previous_hidden_states = states

            logits = self.logits_processor(heads[head_index],
                                           states.flatten(0, 1))
            last_tokens = _greedy_next_tokens(self, logits, batch_size)

            if next_tokens_tensors[head_index] is None:
                next_tokens_tensors[head_index] = last_tokens
            else:
                next_tokens_tensors[head_index].copy_(last_tokens)
        return next_tokens_tensors

    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
    ) -> torch.Tensor:
        """Propose ``num_predict_tokens`` tokens per sequence.

        Returns a tensor of token ids of shape
        ``(batch, num_predict_tokens)``. Small enough batches run through a
        cached CUDA graph (captured on first use); others run eagerly.

        Raises:
            ValueError: if more tokens are requested than the model supports.
        """
        if num_predict_tokens > self.max_speculative_tokens:
            raise ValueError(f"Max speculative tokens for model is "
                             f"{self.max_speculative_tokens}, but "
                             f"{num_predict_tokens} were requested")

        # b x 1 x d
        previous_hidden_states = previous_hidden_states.unsqueeze(1)
        # b x 1
        last_tokens = input_ids.unsqueeze(1)
        batch_size = input_ids.size(0)

        next_tokens: List[Optional[torch.Tensor]] = [None] * num_predict_tokens
        if self.cuda_graph_mode and batch_size <= self.cuda_graph_max_batch_size:
            padded_size, static_last_tokens, static_hidden_states = (
                self._prepare_cuda_graph_ios(batch_size, last_tokens,
                                             previous_hidden_states))
            for i in range(num_predict_tokens):
                next_tokens[i] = self.static_cuda_buffers["next_tokens"][
                    i][:padded_size]

            # The captured graph bakes in the number of steps, so it is part
            # of the cache key.
            cg_key = _generate_cg_key(padded_size, num_predict_tokens)
            graph = self.cuda_graphs.get(cg_key)
            if graph is None:
                graph = _capture_cuda_graph(lambda: self.generate_token_ids(
                    padded_size,
                    num_predict_tokens,
                    static_last_tokens,
                    static_hidden_states,
                    next_tokens,
                ))
                self.cuda_graphs[cg_key] = graph
            # Capture does not execute the work, so replay on every call
            # (including the capturing one) to fill the output buffers.
            graph.replay()
        else:
            self.generate_token_ids(
                batch_size,
                num_predict_tokens,
                last_tokens,
                previous_hidden_states,
                next_tokens,
            )

        return torch.cat([tokens[:batch_size] for tokens in next_tokens],
                         dim=-1)

    def maybe_load_weight(self, param, loaded_weight):
        """Load ``loaded_weight`` into ``param`` using its ``weight_loader``
        (``default_weight_loader`` if absent); no-op when ``param`` is None."""
        if param is not None:
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``weights``, stripping any ``speculator.`` prefix.

        ``head*`` tensors are also loaded into the matching ``qhead*``
        parameters. Names with no matching parameter are skipped.
        """
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            name = name.replace("speculator.", "")
            self.maybe_load_weight(params_dict.get(name), loaded_weight)
            if name.startswith("head"):
                self.maybe_load_weight(
                    params_dict.get(name.replace("head", "qhead")),
                    loaded_weight)


class ArcticLSTMSpeculator(nn.Module, SpeculatorTPInit):
    """
    Recurrent variant of the speculative models introduced in
    "Accelerating Production LLMs with Combined Token/Embedding
    Speculators"
    https://arxiv.org/pdf/2404.19124

    ``config.method`` selects the step function: ``"sum_rnn"`` (default) or
    ``"sum_lstm"`` (gated cell state; requires tied weights).

    Trained speculators of this type are available on HF hub at:
    https://huggingface.co/ibm-fms and https://huggingface.co/ibm-granite
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        SpeculatorTPInit.__init__(self)

        config = vllm_config.model_config.hf_config
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.input_hidden_dim = config.input_hidden_dim

        # Dimension specs may arrive as "."-separated strings. The parsed
        # lists are written back to the config; _parse_dims is idempotent so
        # reusing the same config object is safe.
        config.inner_dim = _parse_dims(config.inner_dim)
        self.inner_dim = config.inner_dim
        config.emb_dim = _parse_dims(config.emb_dim)
        self.emb_dim = config.emb_dim
        config.proj_dim = _parse_dims(config.proj_dim)
        self.proj_dim = config.proj_dim

        self.max_speculative_tokens = config.num_lookahead_tokens
        self.tie_weights = config.tie_weights
        self.tie_lstm_embs = config.tie_lstm_embs
        self.scale_input = config.scale_input

        self.quantize_lm_head = True
        quant_config = (Fp8ConfigWithEmbedding()
                        if self.quantize_lm_head else None)

        self.method = getattr(config, "method", "sum_rnn")
        self.activation = nn.GELU()

        self.qhead = None
        self.head, self.qhead = _build_lm_heads(
            self.vocab_size,
            self.inner_dim[-1],
            self.max_speculative_tokens,
            tie_weights=self.tie_weights,
            quant_config=quant_config,
            quantize_lm_head=self.quantize_lm_head,
        )

        if self.method == "sum_rnn":
            self._build_sum_rnn_layers()
        elif self.method == "sum_lstm":
            self._build_sum_lstm_layers()

        if self.scale_input:
            self.ln0 = MLPSpeculatorLayerNorm(
                self.input_hidden_dim, elementwise_scale_and_shift=False)

        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim[0] / 2))

        self.config = config
        self.logits_processor = LogitsProcessorOpt(
            vocab_size=config.vocab_size,
            org_vocab_size=config.vocab_size,
            scale=1.0,
            skip_last_gather=True,
        )
        self.sampler = get_sampler()

        self.cuda_graph_mode = False
        self.cuda_graph_max_batch_size = padding_size(
            vllm_config.scheduler_config.max_num_seqs)
        max_bs = self.cuda_graph_max_batch_size
        # Zero-initialized: padded rows past the live batch still go through
        # the embedding lookup, so they must hold valid token ids.
        self.static_cuda_buffers = {
            "last_tokens":
            torch.zeros(max_bs, 1, dtype=torch.long),
            "previous_hidden_states":
            torch.zeros(max_bs, 1, self.input_hidden_dim),
            "cell_states":
            torch.zeros(max_bs, 1, self.inner_dim[-1]),
            # One output buffer per step that generate_proposals accepts.
            "next_tokens": [
                torch.zeros(max_bs, 1, dtype=torch.long) for _ in range(
                    max(self.n_predict, self.max_speculative_tokens))
            ],
        }
        if self.inner_dim[-1] != self.input_hidden_dim:
            # NOTE(review): nothing in this file reads this buffer.
            logger.debug("CREATED NEXT PREVIOUS HIDDEN STATES")
            self.static_cuda_buffers[
                "next_previous_hidden_states"] = torch.zeros(
                    max_bs, 1, self.inner_dim[-1])

        if not vllm_config.model_config.enforce_eager:
            self.cuda_graph_mode = True
            self.cuda_graphs = {}

    def _build_sum_rnn_layers(self) -> None:
        """Create the ``emb``, ``proj`` and ``ln`` stacks for "sum_rnn".

        With ``tie_weights`` only step 0 gets an embedding/norm stack and only
        steps 0 and 1 get projections (``generate_states`` maps later steps
        onto those); otherwise each of the ``n_predict`` steps gets its own.
        """
        embs = []
        for n_i in range(self.n_predict):
            if self.tie_weights and n_i > 0:
                continue
            seqs = [VocabParallelEmbedding(self.vocab_size, self.emb_dim[0])]
            for i in range(1, len(self.emb_dim)):
                logger.debug("ADDING ANOTHER EMB %d", i)
                # NOTE(review): this norm is sized emb_dim[i] but normalizes
                # the emb_dim[i-1]-wide output before it, so it only works
                # when consecutive widths match. Kept as-is because module
                # indices must match checkpoint names -- confirm vs training.
                seqs.append(
                    MLPSpeculatorLayerNorm(self.emb_dim[i],
                                           elementwise_scale_and_shift=True))
                seqs.append(self.activation)
                seqs.append(
                    nn.Linear(self.emb_dim[i - 1], self.emb_dim[i],
                              bias=False))
            embs.append(nn.Sequential(*seqs))
        self.emb = nn.ModuleList(embs)

        projs = []
        for n_i in range(self.n_predict):
            if self.tie_weights and n_i > 1:
                continue
            in_dim = self.input_hidden_dim if n_i == 0 else self.inner_dim[-1]
            seqs = [nn.Linear(in_dim, self.proj_dim[0], bias=False)]
            for i in range(1, len(self.proj_dim)):
                logger.debug("ADDING ANOTHER PROJ %d", i)
                # NOTE(review): same width-ordering caveat as the embedding
                # stack above.
                seqs.append(
                    MLPSpeculatorLayerNorm(self.proj_dim[i],
                                           elementwise_scale_and_shift=True))
                seqs.append(self.activation)
                seqs.append(
                    nn.Linear(self.proj_dim[i - 1], self.proj_dim[i],
                              bias=False))
            projs.append(nn.Sequential(*seqs))
        self.proj = nn.ModuleList(projs)

        lns = []
        for n_i in range(self.n_predict):
            if self.tie_weights and n_i > 0:
                continue
            seqs = [
                MLPSpeculatorLayerNorm(self.inner_dim[0],
                                       elementwise_scale_and_shift=True)
            ]
            for i in range(1, len(self.inner_dim)):
                seqs.append(self.activation)
                seqs.append(
                    nn.Linear(self.inner_dim[i - 1], self.inner_dim[i],
                              bias=False))
                seqs.append(
                    MLPSpeculatorLayerNorm(self.inner_dim[i],
                                           elementwise_scale_and_shift=True))
            lns.append(nn.Sequential(*seqs))
        self.ln = nn.ModuleList(lns)

    def _build_sum_lstm_layers(self) -> None:
        """Create the embeddings, fused gate projections and norms for
        "sum_lstm" (tied weights only)."""
        assert self.tie_weights
        emb_width = self.emb_dim[0]
        self.forget_emb = nn.ModuleList(
            [nn.Embedding(self.vocab_size, emb_width)])
        if not self.tie_lstm_embs:
            self.input_emb = nn.ModuleList(
                [nn.Embedding(self.vocab_size, emb_width)])
            self.cell_emb = nn.ModuleList(
                [nn.Embedding(self.vocab_size, emb_width)])
            self.output_emb = nn.ModuleList(
                [nn.Embedding(self.vocab_size, emb_width)])

        # Each projection emits the forget/input/output/cell pre-activations
        # side by side (see load_weights for how they are fused).
        gates_width = self.proj_dim[0] * 4
        self.projs = nn.ModuleList([
            nn.Linear(self.input_hidden_dim, gates_width, bias=False),
            nn.Linear(self.inner_dim[-1], gates_width, bias=False),
        ])
        self.cell_ln = nn.ModuleList([
            MLPSpeculatorLayerNorm(self.inner_dim[0],
                                   elementwise_scale_and_shift=True)
        ])
        self.state_ln = nn.ModuleList([
            MLPSpeculatorLayerNorm(self.inner_dim[0],
                                   elementwise_scale_and_shift=True)
        ])

    def _prepare_cuda_graph_ios(
        self,
        size,
        last_tokens,
        previous_hidden_states,
        hidden_state_buffers,
        cell_states=None,
        use_lstm=False,
    ):
        """Copy the live batch into the static input buffers.

        Returns ``(padded_size, static_last_tokens, static_hidden_states)``
        plus ``static_cell_states`` when ``use_lstm``; the static tensors are
        views of the first ``padded_size`` rows.
        """
        self.static_cuda_buffers["last_tokens"][:size] = last_tokens
        if cell_states is not None:
            self.static_cuda_buffers["cell_states"][:size] = cell_states
        if previous_hidden_states is not None:
            hidden_state_buffers[:size] = previous_hidden_states

        padded_size = padding_size(size) if self.cuda_graph_mode else size

        static_last_tokens = self.static_cuda_buffers[
            "last_tokens"][:padded_size]
        static_hidden_states = hidden_state_buffers[:padded_size]
        if use_lstm:
            static_cell_states = self.static_cuda_buffers[
                "cell_states"][:padded_size]
            return (
                padded_size,
                static_last_tokens,
                static_hidden_states,
                static_cell_states,
            )
        return (padded_size, static_last_tokens, static_hidden_states)

    def generate_states(
        self,
        last_tokens: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        head_index: int,
        cell_states: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Compute the speculator state for step ``head_index``.

        Returns the new state for "sum_rnn", or ``(state, cell_states)`` for
        "sum_lstm".
        """
        if head_index == 0 and self.scale_input:
            previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2

        # Tied models reuse the step-0 embedding/norm for every step and the
        # step-1 projection for every step after the first.
        actual_i = 0 if self.tie_weights else head_index
        actual_proj_i = (1 if self.tie_weights and head_index >= 2 else
                         head_index)

        if self.method == "sum_lstm":
            assert self.tie_lstm_embs
            width = self.proj_dim[0]
            # One shared token embedding, tiled across the four gate slots.
            z = self.forget_emb[actual_i](last_tokens).repeat(1, 1, 4)
            states = self.projs[actual_proj_i](previous_hidden_states)
            added_states = torch.add(states,
                                     z,
                                     alpha=self.emb_weight / self.state_weight)

            forget_input_output, cell_candidate = added_states.split(
                [width * 3, width], dim=-1)
            forget_gate, input_gate, output_gate = torch.sigmoid(
                forget_input_output).split([width, width, width], dim=-1)

            cell_candidate = self.activation(
                self.cell_ln[actual_i](cell_candidate))  # b n d
            cell_candidate = cell_candidate * input_gate

            cell_states = cell_states * forget_gate
            cell_states = cell_states + cell_candidate

            state_candidate = self.activation(
                self.state_ln[actual_i](cell_states))
            state = state_candidate * output_gate
            return state, cell_states

        # Project and predict
        z = self.emb[actual_i](last_tokens)  # b k d
        states = self.proj[actual_proj_i](previous_hidden_states)

        # Weighted add of state_weight*state and emb_weight*z
        # Let subsequent LN take care of denominator
        # state_weight is close to 1, so shouldn't be any precision issues
        states.add_(z, alpha=self.emb_weight / self.state_weight)

        states = self.activation(self.ln[actual_i](states))  # b k d
        return states

    def generate_token_ids(
        self,
        batch_size: int,
        num_predict_tokens: int,
        last_tokens: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        next_tokens_tensors: List[Optional[torch.Tensor]],
        cell_states: Optional[torch.Tensor] = None,
    ) -> List[Optional[torch.Tensor]]:
        """Greedily predict ``num_predict_tokens`` tokens, one per step.

        Step ``i``'s token ids are written into ``next_tokens_tensors[i]``:
        copied in place when a (static) tensor is already there, otherwise
        stored as a new tensor. The same list is returned.
        """
        use_qhead = (self.qhead is not None
                     and batch_size <= _QUANTIZED_HEAD_MAX_BATCH_SIZE)
        heads = self.qhead if use_qhead else self.head
        for head_index in range(num_predict_tokens):
            if self.method == "sum_lstm":
                states, cell_states = self.generate_states(
                    last_tokens, previous_hidden_states, head_index,
                    cell_states)
            else:
                states = self.generate_states(last_tokens,
                                              previous_hidden_states,
                                              head_index)
            previous_hidden_states = states

            logits = self.logits_processor(heads[head_index],
                                           states.flatten(0, 1))
            last_tokens = _greedy_next_tokens(self, logits, batch_size)

            if next_tokens_tensors[head_index] is None:
                next_tokens_tensors[head_index] = last_tokens
            else:
                next_tokens_tensors[head_index].copy_(last_tokens)
        return next_tokens_tensors

    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
    ) -> torch.Tensor:
        """Propose ``num_predict_tokens`` tokens per sequence.

        Returns a tensor of token ids of shape
        ``(batch, num_predict_tokens)``. Small enough batches run through a
        cached CUDA graph (captured on first use); others run eagerly on the
        live tensors.

        Raises:
            ValueError: if more tokens are requested than the model supports.
        """
        if num_predict_tokens > self.max_speculative_tokens:
            raise ValueError(f"Max speculative tokens for model is "
                             f"{self.max_speculative_tokens}, but "
                             f"{num_predict_tokens} were requested")

        # b x 1 x d
        previous_hidden_states = previous_hidden_states.unsqueeze(1)
        # b x 1
        last_tokens = input_ids.unsqueeze(1)
        batch_size = input_ids.size(0)

        use_lstm = self.method == "sum_lstm"
        cell_states = None
        if use_lstm:
            # The LSTM cell state starts at zero for every proposal round.
            state_shape = list(previous_hidden_states.shape)
            state_shape[-1] = self.inner_dim[-1]
            cell_states = torch.zeros(
                state_shape,
                device=previous_hidden_states.device,
                dtype=previous_hidden_states.dtype,
            )

        next_tokens: List[Optional[torch.Tensor]] = [None] * num_predict_tokens
        if self.cuda_graph_mode and batch_size <= self.cuda_graph_max_batch_size:
            padded_size, static_last_tokens, static_hidden_states, *rest = (
                self._prepare_cuda_graph_ios(
                    batch_size,
                    last_tokens,
                    previous_hidden_states,
                    self.static_cuda_buffers["previous_hidden_states"],
                    cell_states,
                    use_lstm=use_lstm,
                ))
            static_cell_states = rest[0] if rest else None
            for i in range(num_predict_tokens):
                next_tokens[i] = self.static_cuda_buffers["next_tokens"][
                    i][:padded_size]

            # The captured graph bakes in the number of steps, so it is part
            # of the cache key.
            cg_key = _generate_cg_key(padded_size, num_predict_tokens)
            graph = self.cuda_graphs.get(cg_key)
            if graph is None:
                graph = _capture_cuda_graph(lambda: self.generate_token_ids(
                    padded_size,
                    num_predict_tokens,
                    static_last_tokens,
                    static_hidden_states,
                    next_tokens,
                    cell_states=static_cell_states,
                ))
                self.cuda_graphs[cg_key] = graph
            # Capture does not execute the work, so replay on every call
            # (including the capturing one) to fill the output buffers.
            graph.replay()
        else:
            self.generate_token_ids(
                batch_size,
                num_predict_tokens,
                last_tokens,
                previous_hidden_states,
                next_tokens,
                cell_states=cell_states,
            )

        return torch.cat([tokens[:batch_size] for tokens in next_tokens],
                         dim=-1)

    def maybe_load_weight(self, param, loaded_weight):
        """Load ``loaded_weight`` into ``param`` using its ``weight_loader``
        (``default_weight_loader`` if absent); no-op when ``param`` is None."""
        if param is not None:
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``weights``, stripping any ``speculator.`` prefix.

        For "sum_lstm" the separate forget/input/output/cell projections are
        concatenated (in that order, matching ``generate_states``) into the
        fused ``projs`` parameters. ``head*`` tensors are also loaded into the
        matching ``qhead*`` parameters. Unmatched names are skipped.

        Raises:
            KeyError: if a fused ``projs`` parameter's component is missing.
        """
        # Normalize names up front so the fixups below use the same names as
        # named_parameters().
        weights = collections.OrderedDict(
            (name.replace("speculator.", ""), tensor)
            for name, tensor in weights)

        if self.method == "sum_lstm" and self.tie_lstm_embs:
            # Only forget_emb exists when the LSTM embeddings are tied.
            for unused in ("input_emb.0.weight", "cell_emb.0.weight",
                           "output_emb.0.weight"):
                weights.pop(unused, None)

        for name, _ in self.named_parameters():
            if "projs." in name:
                logger.debug("REPLACING %s", name)
                forget_proj = weights.pop(name.replace("projs",
                                                       "forget_proj"))
                input_proj = weights.pop(name.replace("projs", "input_proj"))
                output_proj = weights.pop(
                    name.replace("projs", "output_proj"))
                cell_proj = weights.pop(name.replace("projs", "cell_proj"))
                weights[name] = torch.cat(
                    [forget_proj, input_proj, output_proj, cell_proj])

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights.items():
            logger.debug("LOADING %s", name)
            self.maybe_load_weight(params_dict.get(name), loaded_weight)
            if name.startswith("head"):
                self.maybe_load_weight(
                    params_dict.get(name.replace("head", "qhead")),
                    loaded_weight)