in arctic_inference/vllm/swiftkv/llama_swiftkv.py [0:0]
def _init_decode_runner(self, vllm_config: VllmConfig):
    """Build the SwiftKV decode runner and its CUDA-graph input buffers.

    Side effects:
      * Replaces ``vllm_config.compilation_config`` with a shallow copy whose
        ``inductor_compile_config`` is itself copied. This mutates the
        caller's ``vllm_config`` object. Presumably the goal is to let the
        decode runner's compilation use or alter those settings without
        touching the original config object -- TODO confirm.
      * Sets ``self.decode_runner`` to a ``LlamaSwiftKVDecodeRunner`` built
        from that config with ``model=self``.
      * Sets ``self.cuda_graph_max_batch_size``. It is the largest entry of
        ``cudagraph_capture_sizes``, or 0 when that collection is empty or
        falsy.
      * Only when capture sizes are present, sets ``self.decode_runner.inputs``
        to a dict of uninitialized CUDA tensors, each with
        ``cuda_graph_max_batch_size`` rows. Otherwise ``inputs`` is left unset.

    Args:
        vllm_config: engine configuration; its ``compilation_config`` is
            swapped for a copy (see above).
    """
    # Install the shallow copy first, then give it a private copy of the
    # inductor settings (presumably a dict) so later edits don't leak back.
    vllm_config.compilation_config = copy.copy(
        vllm_config.compilation_config)
    compilation_config = vllm_config.compilation_config
    compilation_config.inductor_compile_config = (
        compilation_config.inductor_compile_config.copy())

    self.decode_runner = LlamaSwiftKVDecodeRunner(
        vllm_config=vllm_config, model=self)
    hf_config = vllm_config.model_config.hf_config

    # Re-read through vllm_config: the runner constructor runs before this
    # lookup, matching the original evaluation order.
    capture_sizes = vllm_config.compilation_config.cudagraph_capture_sizes
    if not capture_sizes:
        # No CUDA-graph capture sizes: nothing to pre-allocate.
        self.cuda_graph_max_batch_size = 0
        return

    max_batch = max(capture_sizes)
    self.cuda_graph_max_batch_size = max_batch

    # Head geometry is taken from the last layer's attention module.
    # Assumes every SwiftKV layer shares it -- TODO confirm.
    last_attn = self.layers[-1].self_attn.attn
    kv_heads = last_attn.num_kv_heads
    head_dim = last_attn.head_size
    swiftkv_layers = hf_config.num_hidden_layers - hf_config.num_key_value_layers
    kv_width = swiftkv_layers * kv_heads * head_dim

    hidden = hf_config.hidden_size
    # torch.empty without dtype uses torch's current default dtype.
    # Presumably the caller sets it to the model dtype -- verify.
    self.decode_runner.inputs = {
        "hidden_states": torch.empty(max_batch, hidden, device="cuda"),
        "residual": torch.empty(max_batch, hidden, device="cuda"),
        "positions": torch.empty(max_batch, dtype=torch.long,
                                 device="cuda"),
        "k_states": torch.empty(max_batch, kv_width, device="cuda"),
        "v_states": torch.empty(max_batch, kv_width, device="cuda"),
    }