def _init_decode_runner()

in arctic_inference/vllm/swiftkv/llama_swiftkv.py [0:0]


    def _init_decode_runner(self, vllm_config: VllmConfig):
        """Create the SwiftKV decode runner and its static input buffers.

        ``vllm_config.compilation_config`` is rebound to a shallow copy of
        itself, and that copy receives its own copy of
        ``inductor_compile_config``, before the runner is constructed. Note
        that ``vllm_config`` itself is modified in place by this rebinding.

        When ``cudagraph_capture_sizes`` is non-empty, the largest capture
        size is stored in ``self.cuda_graph_max_batch_size`` and one
        uninitialized CUDA buffer per decode input is stored in
        ``self.decode_runner.inputs``. Otherwise
        ``self.cuda_graph_max_batch_size`` is set to 0 and no buffers are
        allocated.

        Args:
            vllm_config: Configuration providing ``compilation_config`` and
                ``model_config.hf_config``.
        """
        # NOTE(review): presumably the copies keep later edits to the
        # compilation settings from leaking into the original config
        # objects -- confirm against LlamaSwiftKVDecodeRunner.
        vllm_config.compilation_config = copy.copy(
            vllm_config.compilation_config)
        compile_cfg = vllm_config.compilation_config
        compile_cfg.inductor_compile_config = (
            compile_cfg.inductor_compile_config.copy())

        self.decode_runner = LlamaSwiftKVDecodeRunner(
            vllm_config=vllm_config, model=self)

        hf_config = vllm_config.model_config.hf_config
        capture_sizes = vllm_config.compilation_config.cudagraph_capture_sizes
        if not capture_sizes:
            # No capture sizes configured: nothing to pre-allocate.
            self.cuda_graph_max_batch_size = 0
            return

        max_batch = max(capture_sizes)
        self.cuda_graph_max_batch_size = max_batch

        # KV head geometry is read from the last layer's attention module.
        last_attn = self.layers[-1].self_attn.attn
        kv_heads = last_attn.num_kv_heads
        head_dim = last_attn.head_size
        # Layers beyond the first ``num_key_value_layers`` each contribute
        # one ``kv_heads * head_dim`` slice to the k/v buffers.
        extra_layers = (hf_config.num_hidden_layers -
                        hf_config.num_key_value_layers)
        kv_width = extra_layers * kv_heads * head_dim

        hidden_shape = (max_batch, hf_config.hidden_size)
        kv_shape = (max_batch, kv_width)
        # Buffers are sized for the largest capture batch. Except for
        # ``positions`` they use torch's default dtype at call time.
        self.decode_runner.inputs = {
            "hidden_states": torch.empty(hidden_shape, device="cuda"),
            "residual": torch.empty(hidden_shape, device="cuda"),
            "positions": torch.empty((max_batch,),
                                     dtype=torch.long,
                                     device="cuda"),
            "k_states": torch.empty(kv_shape, device="cuda"),
            "v_states": torch.empty(kv_shape, device="cuda"),
        }