in arctic_inference/vllm/spec_dec/arctic_speculator.py [0:0]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
    """Build the speculator's LM heads, token embeddings, projections,
    layer norms and the pre-allocated static buffers.

    Args:
        vllm_config: Engine configuration. ``model_config.hf_config`` must
            provide ``n_predict``, ``vocab_size``, ``input_hidden_dim``,
            ``inner_dim``, ``emb_dim``, ``proj_dim``,
            ``num_lookahead_tokens``, ``tie_weights``, ``tie_lstm_embs``
            and ``scale_input``; ``method`` is optional and defaults to
            ``"sum_rnn"``. ``model_config.enforce_eager`` and
            ``scheduler_config.max_num_seqs`` are read as well.
        prefix: Accepted for interface compatibility; not used in this
            constructor.

    Side effects:
        ``hf_config.inner_dim``, ``hf_config.emb_dim`` and
        ``hf_config.proj_dim`` are replaced in place by lists of ints.
    """
    import logging

    super().__init__()
    SpeculatorTPInit.__init__(self)

    log = logging.getLogger(__name__)

    def _as_dim_list(value) -> list:
        """Return ``value`` as a list of ints.

        Accepts the dot-separated string form (e.g. ``"4096.2048"``) as
        well as an already parsed sequence. The latter occurs when the
        same hf_config object is used to build the module more than once,
        because the parsed list is written back onto the config below.
        """
        if isinstance(value, str):
            return [int(part) for part in value.split(".")]
        return [int(part) for part in value]

    config = vllm_config.model_config.hf_config
    self.n_predict = config.n_predict
    self.vocab_size = config.vocab_size
    self.input_hidden_dim = config.input_hidden_dim

    # The parsed lists are stored on the config too, so anything reading
    # `self.config.<name>_dim` later sees lists rather than strings.
    config.inner_dim = _as_dim_list(config.inner_dim)
    self.inner_dim = config.inner_dim
    config.emb_dim = _as_dim_list(config.emb_dim)
    self.emb_dim = config.emb_dim
    config.proj_dim = _as_dim_list(config.proj_dim)
    self.proj_dim = config.proj_dim

    self.max_speculative_tokens = config.num_lookahead_tokens
    self.tie_weights = config.tie_weights
    self.tie_lstm_embs = config.tie_lstm_embs
    self.scale_input = config.scale_input

    # LM-head quantization is hard-wired on in this constructor.
    self.quantize_lm_head = True
    quant_config = (Fp8ConfigWithEmbedding()
                    if self.quantize_lm_head else None)

    self.method = getattr(config, "method", "sum_rnn")
    # A single GELU instance is reused inside every Sequential below.
    self.activation = nn.GELU()

    # ------------------------------------------------------------------
    # LM heads
    # ------------------------------------------------------------------
    self.qhead = None
    if self.tie_weights:
        head = ParallelLMHead(
            self.vocab_size,
            self.inner_dim[-1],
            bias=False,
            quant_config=quant_config,
            skip_quantization=True,
        )
        # `[head] * k` repeats the same module object: every speculative
        # position shares one head.
        self.head = nn.ModuleList([head] * self.max_speculative_tokens)

        if self.quantize_lm_head:
            # Second head built with skip_quantization=False and an
            # explicitly assigned FP8 linear method.
            # NOTE(review): presumably the quantized twin of `head`;
            # confirm against the weight-loading code.
            qhead = ParallelLMHead(
                self.vocab_size,
                self.inner_dim[-1],
                bias=False,
                quant_config=quant_config,
                skip_quantization=False,
            )
            qhead.quant_method = OriginalFp8LinearMethod(
                quant_config=quant_config)
            self.qhead = nn.ModuleList([qhead] *
                                       self.max_speculative_tokens)
    else:
        # Untied: one independent head per speculative position.
        self.head = nn.ModuleList([
            ParallelLMHead(
                self.vocab_size,
                self.inner_dim[-1],
                bias=False,
                quant_config=quant_config,
            ) for _ in range(self.max_speculative_tokens)
        ])

    # ------------------------------------------------------------------
    # Method-specific layers. Any `method` value other than "sum_rnn" or
    # "sum_lstm" creates none of these layers.
    # ------------------------------------------------------------------
    if self.method == "sum_rnn":
        # Token embeddings: one stack per step, or a single stack when
        # weights are tied.
        embs = []
        for n_i in range(self.n_predict):
            if not self.tie_weights or n_i == 0:
                seqs = [
                    VocabParallelEmbedding(self.vocab_size,
                                           self.emb_dim[0])
                ]
                for i in range(1, len(self.emb_dim)):
                    log.debug("ADDING ANOTHER EMB %s", i)
                    seqs.append(
                        MLPSpeculatorLayerNorm(
                            self.emb_dim[i],
                            elementwise_scale_and_shift=True))
                    seqs.append(self.activation)
                    seqs.append(
                        nn.Linear(self.emb_dim[i - 1],
                                  self.emb_dim[i],
                                  bias=False))
                embs.append(nn.Sequential(*seqs))
        self.emb = nn.ModuleList(embs)

        # State projections. With tied weights two stacks are kept: step
        # 0 consumes `input_hidden_dim`, later steps consume
        # `inner_dim[-1]`, so they cannot share an input layer.
        projs = []
        for n_i in range(self.n_predict):
            if not self.tie_weights or n_i <= 1:
                in_features = (self.input_hidden_dim
                               if n_i == 0 else self.inner_dim[-1])
                seqs = [
                    nn.Linear(
                        in_features,
                        self.proj_dim[0],
                        bias=False,
                    )
                ]
                for i in range(1, len(self.proj_dim)):
                    log.debug("ADDING ANOTHER PROJ %s", i)
                    seqs.append(
                        MLPSpeculatorLayerNorm(
                            self.proj_dim[i],
                            elementwise_scale_and_shift=True))
                    seqs.append(self.activation)
                    seqs.append(
                        nn.Linear(self.proj_dim[i - 1],
                                  self.proj_dim[i],
                                  bias=False))
                projs.append(nn.Sequential(*seqs))
        self.proj = nn.ModuleList(projs)

        # Layer-norm stacks: one per step, or a single stack when tied.
        lns = []
        for n_i in range(self.n_predict):
            if not self.tie_weights or n_i == 0:
                seqs = [
                    MLPSpeculatorLayerNorm(
                        self.inner_dim[0],
                        elementwise_scale_and_shift=True)
                ]
                for i in range(1, len(self.inner_dim)):
                    seqs.append(self.activation)
                    seqs.append(
                        nn.Linear(self.inner_dim[i - 1],
                                  self.inner_dim[i],
                                  bias=False))
                    seqs.append(
                        MLPSpeculatorLayerNorm(
                            self.inner_dim[i],
                            elementwise_scale_and_shift=True))
                lns.append(nn.Sequential(*seqs))
        self.ln = nn.ModuleList(lns)

    elif self.method == "sum_lstm":
        # This variant is only built with tied weights.
        assert self.tie_weights

        self.forget_emb = nn.ModuleList(
            [nn.Embedding(self.vocab_size, self.emb_dim[0])])
        if not self.tie_lstm_embs:
            # Separate embedding tables are created only when the LSTM
            # embeddings are not tied.
            self.input_emb = nn.ModuleList(
                [nn.Embedding(self.vocab_size, self.emb_dim[0])])
            self.cell_emb = nn.ModuleList(
                [nn.Embedding(self.vocab_size, self.emb_dim[0])])
            self.output_emb = nn.ModuleList(
                [nn.Embedding(self.vocab_size, self.emb_dim[0])])

        # Two projections, each producing 4 * proj_dim[0] features: the
        # first takes `input_hidden_dim`, the second `inner_dim[-1]`.
        self.projs = nn.ModuleList([
            nn.Linear(self.input_hidden_dim,
                      self.proj_dim[0] * 4,
                      bias=False),
            nn.Linear(self.inner_dim[-1],
                      self.proj_dim[0] * 4,
                      bias=False),
        ])
        self.cell_ln = nn.ModuleList([
            MLPSpeculatorLayerNorm(self.inner_dim[0],
                                   elementwise_scale_and_shift=True)
        ])
        self.state_ln = nn.ModuleList([
            MLPSpeculatorLayerNorm(self.inner_dim[0],
                                   elementwise_scale_and_shift=True)
        ])

    if self.scale_input:
        # Extra layer norm (no learned scale/shift) sized to the incoming
        # hidden state.
        self.ln0 = MLPSpeculatorLayerNorm(
            self.input_hidden_dim, elementwise_scale_and_shift=False)

    # Scalar mixing weights derived from n_predict and inner_dim[0].
    self.state_weight = 0.5**(0.5 / config.n_predict)
    self.emb_weight = math.sqrt(
        (1 - self.state_weight**2) * (self.inner_dim[0] / 2))

    self.config = config
    self.logits_processor = LogitsProcessorOpt(
        vocab_size=config.vocab_size,
        org_vocab_size=config.vocab_size,
        scale=1.0,
        skip_last_gather=True,
    )
    self.sampler = get_sampler()

    # ------------------------------------------------------------------
    # Static buffers. They are allocated whether or not CUDA-graph mode is
    # enabled and are sized from the padded `max_num_seqs`.
    # ------------------------------------------------------------------
    self.cuda_graph_max_batch_size = padding_size(
        vllm_config.scheduler_config.max_num_seqs)
    max_bs = self.cuda_graph_max_batch_size
    self.static_cuda_buffers = {
        "last_tokens":
        torch.empty(max_bs, 1, dtype=torch.long),
        "previous_hidden_states":
        torch.empty(max_bs, 1, self.input_hidden_dim),
        "cell_states":
        torch.empty(max_bs, 1, self.inner_dim[-1]),
        "next_tokens": [
            torch.empty(max_bs, 1, dtype=torch.long)
            for _ in range(self.n_predict)
        ],
    }
    if self.inner_dim[-1] != self.input_hidden_dim:
        # A separate buffer is needed when the speculator's output width
        # differs from the incoming hidden-state width.
        log.debug("CREATED NEXT PREVIOUS HIDDEN STATES")
        self.static_cuda_buffers[
            "next_previous_hidden_states"] = torch.empty(
                max_bs, 1, self.inner_dim[-1])

    # CUDA-graph mode is on unless eager execution is enforced.
    self.cuda_graph_mode = not vllm_config.model_config.enforce_eager
    # Starts empty and is always defined, so attribute access is safe in
    # eager mode too.
    # NOTE(review): presumably filled with captured graphs elsewhere in
    # the class; confirm.
    self.cuda_graphs = {}