def __init__()

in arctic_inference/vllm/spec_dec/arctic_speculator.py [0:0]


    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the Arctic speculator modules from the model's HF config.

        Hyper-parameters are read from ``vllm_config.model_config.hf_config``.
        The ``inner_dim`` / ``emb_dim`` / ``proj_dim`` specs are parsed from
        "."-separated strings into integer lists and written back onto that
        config object.

        Depending on ``config.method`` (default ``"sum_rnn"``), either the
        ``emb`` / ``proj`` / ``ln`` stacks (``"sum_rnn"``) or the
        ``forget_emb`` / ``projs`` / ``cell_ln`` / ``state_ln`` modules
        (``"sum_lstm"``) are created. LM heads, a logits processor, a sampler
        and a set of preallocated static buffers are always created.

        Args:
            vllm_config: vLLM configuration. Besides the HF config, this
                constructor reads ``scheduler_config.max_num_seqs`` to size the
                static buffers and ``model_config.enforce_eager`` to decide
                whether ``cuda_graph_mode`` is enabled.
            prefix: Accepted for interface compatibility; not used in this
                constructor.

        Raises:
            ValueError: If ``method == "sum_lstm"`` without ``tie_weights``, or
                if a dimension spec contains a non-integer component.
        """
        super().__init__()

        SpeculatorTPInit.__init__(self)

        config = vllm_config.model_config.hf_config

        def _parse_dims(spec):
            # Dimension specs normally arrive as "."-separated integer strings.
            # An int or an already-parsed sequence is accepted as well, so that
            # building a second speculator from the same (already mutated)
            # config object does not fail on ``list.split``.
            if isinstance(spec, str):
                return [int(part) for part in spec.split(".")]
            if isinstance(spec, int):
                return [spec]
            return [int(part) for part in spec]

        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.input_hidden_dim = config.input_hidden_dim
        # Parsed lists are written back so that ``self.config`` (stored below)
        # exposes integer lists rather than the raw strings.
        config.inner_dim = _parse_dims(config.inner_dim)
        self.inner_dim = config.inner_dim
        config.emb_dim = _parse_dims(config.emb_dim)
        self.emb_dim = config.emb_dim
        config.proj_dim = _parse_dims(config.proj_dim)
        self.proj_dim = config.proj_dim

        self.max_speculative_tokens = config.num_lookahead_tokens

        self.tie_weights = config.tie_weights
        self.tie_lstm_embs = config.tie_lstm_embs
        self.scale_input = config.scale_input
        # Hard-coded: the FP8 quant config below is therefore always created.
        self.quantize_lm_head = True

        quant_config = (Fp8ConfigWithEmbedding()
                        if self.quantize_lm_head else None)
        self.method = getattr(config, "method", "sum_rnn")

        # A single parameter-free GELU instance is reused inside every
        # Sequential stack built below.
        self.activation = nn.GELU()
        self.qhead = None
        if self.tie_weights:
            # One head object repeated in the ModuleList: every speculative
            # position shares the same weights.
            head = ParallelLMHead(
                self.vocab_size,
                self.inner_dim[-1],
                bias=False,
                quant_config=quant_config,
                skip_quantization=True,
            )
            self.head = nn.ModuleList([head] * self.max_speculative_tokens)

            if self.quantize_lm_head:
                # A second tied head with skip_quantization=False, driven by
                # OriginalFp8LinearMethod; likewise shared across positions.
                qhead = ParallelLMHead(
                    self.vocab_size,
                    self.inner_dim[-1],
                    bias=False,
                    quant_config=quant_config,
                    skip_quantization=False,
                )
                qhead.quant_method = OriginalFp8LinearMethod(
                    quant_config=quant_config)
                self.qhead = nn.ModuleList([qhead] *
                                           self.max_speculative_tokens)
        else:
            # Untied: an independent head per speculative position.
            self.head = nn.ModuleList([
                ParallelLMHead(
                    self.vocab_size,
                    self.inner_dim[-1],
                    bias=False,
                    quant_config=quant_config,
                ) for _ in range(self.max_speculative_tokens)
            ])

        if self.method == "sum_rnn":
            # Token embedding stacks. With tied weights only the first
            # predictor step gets its own stack.
            embs = []
            for n_i in range(self.n_predict):
                if not self.tie_weights or n_i == 0:
                    seqs = [
                        VocabParallelEmbedding(self.vocab_size,
                                               self.emb_dim[0])
                    ]
                    for i in range(1, len(self.emb_dim)):
                        seqs.append(
                            MLPSpeculatorLayerNorm(
                                self.emb_dim[i],
                                elementwise_scale_and_shift=True))
                        seqs.append(self.activation)
                        seqs.append(
                            nn.Linear(self.emb_dim[i - 1],
                                      self.emb_dim[i],
                                      bias=False))
                    embs.append(nn.Sequential(*seqs))
            self.emb = nn.ModuleList(embs)

            # Projection stacks. Step 0 projects from input_hidden_dim; later
            # steps project from inner_dim[-1]. With tied weights both kinds
            # exist exactly once (n_i == 0 and n_i == 1).
            projs = []
            for n_i in range(self.n_predict):
                if not self.tie_weights or n_i <= 1:
                    seqs = [
                        nn.Linear(
                            (self.input_hidden_dim
                             if n_i == 0 else self.inner_dim[-1]),
                            self.proj_dim[0],
                            bias=False,
                        )
                    ]
                    for i in range(1, len(self.proj_dim)):
                        seqs.append(
                            MLPSpeculatorLayerNorm(
                                self.proj_dim[i],
                                elementwise_scale_and_shift=True))
                        seqs.append(self.activation)
                        seqs.append(
                            nn.Linear(self.proj_dim[i - 1],
                                      self.proj_dim[i],
                                      bias=False))
                    projs.append(nn.Sequential(*seqs))
            self.proj = nn.ModuleList(projs)

            # Norm / MLP stacks over inner_dim; one per step unless tied.
            lns = []
            for n_i in range(self.n_predict):
                if not self.tie_weights or n_i == 0:
                    seqs = [
                        MLPSpeculatorLayerNorm(
                            self.inner_dim[0],
                            elementwise_scale_and_shift=True)
                    ]
                    for i in range(1, len(self.inner_dim)):
                        seqs.append(self.activation)
                        seqs.append(
                            nn.Linear(self.inner_dim[i - 1],
                                      self.inner_dim[i],
                                      bias=False))
                        seqs.append(
                            MLPSpeculatorLayerNorm(
                                self.inner_dim[i],
                                elementwise_scale_and_shift=True))
                    lns.append(nn.Sequential(*seqs))
            self.ln = nn.ModuleList(lns)

        elif self.method == "sum_lstm":
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if not self.tie_weights:
                raise ValueError(
                    "method='sum_lstm' requires tie_weights=True")
            self.forget_emb = nn.ModuleList(
                [nn.Embedding(self.vocab_size, self.emb_dim[0])])
            if not self.tie_lstm_embs:
                self.input_emb = nn.ModuleList(
                    [nn.Embedding(self.vocab_size, self.emb_dim[0])])
                self.cell_emb = nn.ModuleList(
                    [nn.Embedding(self.vocab_size, self.emb_dim[0])])
                self.output_emb = nn.ModuleList(
                    [nn.Embedding(self.vocab_size, self.emb_dim[0])])
            # Output width is 4 * proj_dim[0] (presumably one slice per LSTM
            # gate -- confirm against forward()). Index 0 takes the base
            # model's hidden state; index 1 takes the speculator's own state.
            self.projs = nn.ModuleList([
                nn.Linear(self.input_hidden_dim,
                          self.proj_dim[0] * 4,
                          bias=False),
                nn.Linear(self.inner_dim[-1], self.proj_dim[0] * 4,
                          bias=False),
            ])
            self.cell_ln = nn.ModuleList([
                MLPSpeculatorLayerNorm(self.inner_dim[0],
                                       elementwise_scale_and_shift=True)
            ])
            self.state_ln = nn.ModuleList([
                MLPSpeculatorLayerNorm(self.inner_dim[0],
                                       elementwise_scale_and_shift=True)
            ])
        # NOTE(review): any other ``method`` value builds no predictor modules
        # here; confirm whether that should be rejected.

        if self.scale_input:
            self.ln0 = MLPSpeculatorLayerNorm(
                self.input_hidden_dim, elementwise_scale_and_shift=False)

        # state_weight ** n_predict == sqrt(0.5); emb_weight is chosen so that
        # state_weight**2 + emb_weight**2 / (inner_dim[0] / 2) == 1.
        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim[0] / 2))
        self.config = config
        self.logits_processor = LogitsProcessorOpt(
            vocab_size=config.vocab_size,
            org_vocab_size=config.vocab_size,
            scale=1.0,
            skip_last_gather=True,
        )
        self.sampler = get_sampler()

        self.cuda_graph_mode = False

        # Preallocated buffers sized for the padded maximum batch; these are
        # allocated with torch.empty on the default device -- presumably
        # moved / reused for graph capture elsewhere (confirm).
        self.cuda_graph_max_batch_size = padding_size(
            vllm_config.scheduler_config.max_num_seqs)
        self.static_cuda_buffers = {
            "last_tokens":
            torch.empty(self.cuda_graph_max_batch_size, 1, dtype=torch.long),
            "previous_hidden_states":
            torch.empty(self.cuda_graph_max_batch_size, 1,
                        self.input_hidden_dim),
            "cell_states":
            torch.empty(self.cuda_graph_max_batch_size, 1, self.inner_dim[-1]),
            "next_tokens": [
                torch.empty(self.cuda_graph_max_batch_size,
                            1,
                            dtype=torch.long) for _ in range(self.n_predict)
            ],
        }
        # A separate buffer is only needed when the speculator's state width
        # differs from the base model's hidden width.
        if self.inner_dim[-1] != self.input_hidden_dim:
            self.static_cuda_buffers[
                "next_previous_hidden_states"] = torch.empty(
                    self.cuda_graph_max_batch_size, 1, self.inner_dim[-1])

        if not vllm_config.model_config.enforce_eager:
            self.cuda_graph_mode = True
            self.cuda_graphs = {}