in arctic_inference/vllm/spec_dec/arctic_speculator.py [0:0]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
    """Build the speculator's per-step modules and runtime bookkeeping.

    Everything is derived from ``vllm_config``: the HF config found at
    ``vllm_config.model_config.hf_config`` supplies the dimensions and
    flags, ``model_config.enforce_eager`` selects eager vs. CUDA-graph
    mode, and ``scheduler_config.max_num_seqs`` sizes the static buffers.

    Args:
        vllm_config: Top-level configuration object (keyword-only).
        prefix: Accepted for interface compatibility; not read here.

    Raises:
        AssertionError: If weights are tied while ``n_predict`` is not
            greater than 1.
    """
    super().__init__()
    SpeculatorTPInit.__init__(self)

    hf_config = vllm_config.model_config.hf_config
    self.n_predict = hf_config.n_predict
    self.vocab_size = hf_config.vocab_size
    self.emb_dim = hf_config.emb_dim
    # A configured inner_dim of 0 falls back to emb_dim.
    if hf_config.inner_dim != 0:
        self.inner_dim = hf_config.inner_dim
    else:
        self.inner_dim = hf_config.emb_dim
    self.max_speculative_tokens = hf_config.num_lookahead_tokens
    self.tie_weights = hf_config.tie_weights
    self.scale_input = hf_config.scale_input

    # LM-head quantization is switched on unconditionally here.
    self.quantize_lm_head = True
    quant_config = None
    if self.quantize_lm_head:
        quant_config = Fp8ConfigWithEmbedding()
    self.qhead = None

    num_steps = self.max_speculative_tokens

    def make_embedding():
        return VocabParallelEmbedding(
            hf_config.vocab_size,
            self.inner_dim,
            org_num_embeddings=hf_config.vocab_size,
        )

    def make_head(**quant_kwargs):
        return ParallelLMHead(
            self.vocab_size,
            self.inner_dim,
            bias=False,
            quant_config=quant_config,
            **quant_kwargs,
        )

    def make_layer_norm():
        return MLPSpeculatorLayerNorm(self.inner_dim,
                                      elementwise_scale_and_shift=True)

    if self.tie_weights:
        assert self.n_predict > 1, (
            "You cannot tie weights between stages when only 1 exists")

        # Tied mode: a single module instance is repeated in every slot of
        # the ModuleList, so all steps share the same parameters.
        tied_emb = make_embedding()
        self.emb = nn.ModuleList([tied_emb] * num_steps)

        # The first projection consumes the base model's hidden state,
        # whose width (emb_dim) may differ from inner_dim, so it is kept
        # as its own layer; the remaining steps share one projection.
        entry_proj = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
        shared_proj = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
        proj_stack = [entry_proj]
        proj_stack.extend([shared_proj] * (num_steps - 1))
        self.proj = nn.ModuleList(proj_stack)

        tied_head = make_head(skip_quantization=True)
        self.head = nn.ModuleList([tied_head] * num_steps)

        if self.quantize_lm_head:
            # Second head built with quantization enabled and its quant
            # method explicitly set to the original FP8 linear method.
            tied_qhead = make_head(skip_quantization=False)
            tied_qhead.quant_method = OriginalFp8LinearMethod(
                quant_config=quant_config)
            self.qhead = nn.ModuleList([tied_qhead] * num_steps)

        tied_ln = make_layer_norm()
        self.ln = nn.ModuleList([tied_ln] * num_steps)
    else:
        # Untied mode: every step gets its own freshly built modules.
        self.emb = nn.ModuleList(
            [make_embedding() for _ in range(num_steps)])

        step_projs = []
        for step in range(num_steps):
            # Only step 0 reads the base model's hidden state (emb_dim).
            in_features = self.emb_dim if step == 0 else self.inner_dim
            step_projs.append(
                nn.Linear(in_features, self.inner_dim, bias=False))
        self.proj = nn.ModuleList(step_projs)

        self.head = nn.ModuleList([make_head() for _ in range(num_steps)])
        self.ln = nn.ModuleList(
            [make_layer_norm() for _ in range(num_steps)])

    if self.scale_input:
        # Extra layer norm over emb_dim, without learned scale/shift.
        self.ln0 = MLPSpeculatorLayerNorm(
            self.emb_dim, elementwise_scale_and_shift=False)

    # Fixed mixing coefficients derived from n_predict and inner_dim.
    state_weight = 0.5**(0.5 / hf_config.n_predict)
    self.state_weight = state_weight
    self.emb_weight = math.sqrt(
        (1 - state_weight**2) * (self.inner_dim / 2))
    self.activation = nn.GELU()
    self.config = hf_config

    self.logits_processor = LogitsProcessorOpt(
        vocab_size=hf_config.vocab_size,
        org_vocab_size=hf_config.vocab_size,
        scale=1.0,
        skip_last_gather=True,
    )
    self.sampler = get_sampler()

    # CUDA-graph bookkeeping: stays disabled (batch size 0) when the model
    # config enforces eager execution.
    self.cuda_graph_max_batch_size = 0
    self.cuda_graph_mode = not vllm_config.model_config.enforce_eager
    if self.cuda_graph_mode:
        self.cuda_graphs = {}
        max_batch = padding_size(vllm_config.scheduler_config.max_num_seqs)
        self.cuda_graph_max_batch_size = max_batch

        def token_buffer():
            return torch.empty(max_batch, 1, dtype=torch.long)

        # Preallocated tensors sized for the padded maximum batch;
        # presumably reused as fixed inputs/outputs during graph
        # capture and replay -- confirm against the forward path.
        self.static_cuda_buffers = {
            "last_tokens": token_buffer(),
            "previous_hidden_states":
            torch.empty(max_batch, 1, self.emb_dim),
            "next_tokens": [token_buffer() for _ in range(self.n_predict)],
        }