def __init__()

in arctic_inference/vllm/spec_dec/arctic_speculator.py [0:0]


    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the per-stage speculator modules and optional static buffers.

        Hyper-parameters are read from ``vllm_config.model_config.hf_config``.
        One embedding, projection, LM head and layer norm is registered per
        lookahead stage (``num_lookahead_tokens`` stages). With
        ``tie_weights`` enabled, a single module instance is repeated across
        the stages, except for the first projection, which has its own
        weights.

        Args:
            vllm_config: Engine configuration. Besides the HF config, this
                method reads ``model_config.enforce_eager`` and
                ``scheduler_config.max_num_seqs`` from it.
            prefix: Not read anywhere in this constructor.

        Raises:
            AssertionError: If ``tie_weights`` is set while ``n_predict`` is
                not greater than 1.
        """
        super().__init__()

        SpeculatorTPInit.__init__(self)

        hf_config = vllm_config.model_config.hf_config

        self.n_predict = hf_config.n_predict
        self.vocab_size = hf_config.vocab_size
        self.emb_dim = hf_config.emb_dim
        # A configured inner_dim of 0 falls back to emb_dim.
        if hf_config.inner_dim != 0:
            self.inner_dim = hf_config.inner_dim
        else:
            self.inner_dim = hf_config.emb_dim

        self.max_speculative_tokens = hf_config.num_lookahead_tokens
        num_stages = self.max_speculative_tokens

        self.tie_weights = hf_config.tie_weights
        self.scale_input = hf_config.scale_input

        self.quantize_lm_head = True

        if self.quantize_lm_head:
            quant_config = Fp8ConfigWithEmbedding()
        else:
            quant_config = None

        # Keyword arguments shared by every module of the same kind.
        emb_kwargs = {"org_num_embeddings": hf_config.vocab_size}
        head_kwargs = {"bias": False, "quant_config": quant_config}
        ln_kwargs = {"elementwise_scale_and_shift": True}

        # Stays None unless the tied-weights path below replaces it.
        self.qhead = None
        if self.tie_weights:
            assert self.n_predict > 1, (
                "You cannot tie weights between stages when only 1 exists")

            # Every stage references the very same embedding instance.
            shared_emb = VocabParallelEmbedding(hf_config.vocab_size,
                                                self.inner_dim, **emb_kwargs)
            self.emb = nn.ModuleList([shared_emb for _ in range(num_stages)])

            # The first projection takes emb_dim-wide input (coming from the
            # base model), which may differ from inner_dim, so it is not
            # shared with the remaining stages.
            first_proj = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
            shared_proj = nn.Linear(self.inner_dim,
                                    self.inner_dim,
                                    bias=False)
            stage_projs = [first_proj]
            stage_projs.extend(shared_proj for _ in range(num_stages - 1))
            self.proj = nn.ModuleList(stage_projs)

            shared_head = ParallelLMHead(self.vocab_size,
                                         self.inner_dim,
                                         skip_quantization=True,
                                         **head_kwargs)
            self.head = nn.ModuleList(
                [shared_head for _ in range(num_stages)])

            if self.quantize_lm_head:
                # Second head built with quantization enabled and its quant
                # method explicitly set to OriginalFp8LinearMethod.
                shared_qhead = ParallelLMHead(self.vocab_size,
                                              self.inner_dim,
                                              skip_quantization=False,
                                              **head_kwargs)
                shared_qhead.quant_method = OriginalFp8LinearMethod(
                    quant_config=quant_config)
                self.qhead = nn.ModuleList(
                    [shared_qhead for _ in range(num_stages)])

            shared_ln = MLPSpeculatorLayerNorm(self.inner_dim, **ln_kwargs)
            self.ln = nn.ModuleList([shared_ln for _ in range(num_stages)])

        else:
            # Untied path: each stage gets its own freshly built modules.
            stage_ids = range(num_stages)

            per_stage_emb = [
                VocabParallelEmbedding(hf_config.vocab_size, self.inner_dim,
                                       **emb_kwargs) for _ in stage_ids
            ]
            self.emb = nn.ModuleList(per_stage_emb)

            # Only stage 0 projects from emb_dim; later stages are
            # inner_dim -> inner_dim.
            per_stage_proj = [
                nn.Linear(self.inner_dim if stage else self.emb_dim,
                          self.inner_dim,
                          bias=False) for stage in stage_ids
            ]
            self.proj = nn.ModuleList(per_stage_proj)

            per_stage_head = [
                ParallelLMHead(self.vocab_size, self.inner_dim, **head_kwargs)
                for _ in stage_ids
            ]
            self.head = nn.ModuleList(per_stage_head)

            per_stage_ln = [
                MLPSpeculatorLayerNorm(self.inner_dim, **ln_kwargs)
                for _ in stage_ids
            ]
            self.ln = nn.ModuleList(per_stage_ln)

        if self.scale_input:
            # Extra norm over emb_dim, created without elementwise
            # scale/shift.
            self.ln0 = MLPSpeculatorLayerNorm(
                self.emb_dim, elementwise_scale_and_shift=False)

        # Scalar mixing weights derived from n_predict and inner_dim.
        self.state_weight = 0.5**(0.5 / hf_config.n_predict)
        emb_fraction = 1 - self.state_weight**2
        self.emb_weight = math.sqrt(emb_fraction * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = hf_config
        self.logits_processor = LogitsProcessorOpt(
            vocab_size=hf_config.vocab_size,
            org_vocab_size=hf_config.vocab_size,
            scale=1.0,
            skip_last_gather=True,
        )
        self.sampler = get_sampler()

        # CUDA-graph state is only populated when eager mode is not enforced;
        # otherwise the mode stays False and the max batch size stays 0.
        use_cuda_graphs = not vllm_config.model_config.enforce_eager
        self.cuda_graph_max_batch_size = 0
        self.cuda_graph_mode = use_cuda_graphs
        if use_cuda_graphs:
            self.cuda_graphs = {}
            max_batch = padding_size(
                vllm_config.scheduler_config.max_num_seqs)
            self.cuda_graph_max_batch_size = max_batch

            # Buffers are allocated once at the padded maximum batch size.
            # No device argument is passed to torch.empty here.
            token_shape = (max_batch, 1)
            self.static_cuda_buffers = {
                "last_tokens":
                torch.empty(token_shape, dtype=torch.long),
                "previous_hidden_states":
                torch.empty(max_batch, 1, self.emb_dim),
                "next_tokens": [
                    torch.empty(token_shape, dtype=torch.long)
                    for _ in range(self.n_predict)
                ],
            }