def __init__()

in arctic_inference/vllm/swiftkv/llama_swiftkv.py [0:0]


    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Construct the embedding table, decoder stack, norms and runners.

        Args:
            vllm_config: Supplies the HF model config
                (``model_config.hf_config``) together with the quantization,
                LoRA and cache configs read below.
            prefix: Name prefix for the decoder layers; layer ``idx`` is
                created with ``prefix=f"{prefix}.layers.{idx}"``.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config

        self.vllm_config = vllm_config
        self.quant_config = quant_config
        self.config = hf_config
        self.padding_idx = hf_config.pad_token_id

        # Vocabulary is widened by the LoRA extra-vocab size, once per
        # adapter slot (at least one); without a LoRA config nothing is added.
        if lora_config:
            extra_vocab = lora_config.lora_extra_vocab_size * (
                lora_config.max_loras or 1)
        else:
            extra_vocab = 0
        self.vocab_size = hf_config.vocab_size + extra_vocab
        self.org_vocab_size = hf_config.vocab_size

        hidden_size = hf_config.hidden_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            hidden_size,
            org_num_embeddings=hf_config.vocab_size,
            quant_config=quant_config,
        )

        def layer_kwargs(layer_idx):
            # Keyword arguments shared by both decoder-layer classes.
            return dict(config=hf_config,
                        cache_config=cache_config,
                        quant_config=quant_config,
                        prefix=f"{prefix}.layers.{layer_idx}")

        # Layers [0, num_key_value_layers) are plain Llama decoder layers.
        first_swiftkv_idx = hf_config.num_key_value_layers
        leading_layers = [
            LlamaDecoderLayer(**layer_kwargs(layer_idx))
            for layer_idx in range(first_swiftkv_idx)
        ]
        self.layers = torch.nn.ModuleList(leading_layers)

        # Layers [num_key_value_layers, num_hidden_layers) are SwiftKV layers;
        # they and both norms are instantiated with shift-parallel mode on.
        with model_runner.set_shift_parallel_mode(True):
            trailing_layers = []
            for layer_idx in range(first_swiftkv_idx,
                                   hf_config.num_hidden_layers):
                trailing_layers.append(
                    LlamaSwiftKVDecoderLayer(**layer_kwargs(layer_idx)))
            self.layers.extend(trailing_layers)

            norm_eps = hf_config.rms_norm_eps
            self.norm_swiftkv = RMSNorm(hidden_size, eps=norm_eps)
            self.norm = RMSNorm(hidden_size, eps=norm_eps)

        # Tag every parameter of the SwiftKV layers with shift_parallel_mode.
        swiftkv_layers = self.layers[first_swiftkv_idx:]
        for weight in swiftkv_layers.parameters():
            weight.shift_parallel_mode = True

        self._init_prefill_runner(vllm_config)
        self._init_decode_runner(vllm_config)

        # Imported locally, as in the original; the flag records whether the
        # custom torch library loader returned a truthy result.
        from arctic_inference.py_custom_ops import try_load_torch_library
        self.use_custom_ops = bool(try_load_torch_library())