in arctic_inference/vllm/swiftkv/llama_swiftkv.py [0:0]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Assemble the model's submodules from a vLLM configuration.

    Builds the token embedding, a stack of decoder layers (the first
    ``num_key_value_layers`` are ``LlamaDecoderLayer``, the rest up to
    ``num_hidden_layers`` are ``LlamaSwiftKVDecoderLayer``), two RMSNorm
    modules, and then runs the prefill/decode runner initialisers.

    Args:
        vllm_config: Provides the HF model config as well as the cache,
            quantization and LoRA configs read here.
        prefix: Name prefix; each layer receives
            ``f"{prefix}.layers.{idx}"`` as its own prefix.
    """
    super().__init__()

    hf_config = vllm_config.model_config.hf_config
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config
    lora_config = vllm_config.lora_config

    self.vllm_config = vllm_config
    self.quant_config = quant_config
    self.config = hf_config
    self.padding_idx = hf_config.pad_token_id

    # LoRA, when configured, contributes extra vocabulary entries on top
    # of the base vocabulary (one set per LoRA, at least one set).
    if lora_config:
        extra_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1))
    else:
        extra_vocab = 0
    self.vocab_size = hf_config.vocab_size + extra_vocab
    self.org_vocab_size = hf_config.vocab_size

    self.embed_tokens = VocabParallelEmbedding(
        self.vocab_size,
        hf_config.hidden_size,
        org_num_embeddings=hf_config.vocab_size,
        quant_config=quant_config,
    )

    num_kv_layers = hf_config.num_key_value_layers
    num_layers = hf_config.num_hidden_layers

    # Leading layers: regular Llama decoder layers.
    leading_layers = [
        LlamaDecoderLayer(config=hf_config,
                          cache_config=cache_config,
                          quant_config=quant_config,
                          prefix=f"{prefix}.layers.{layer_idx}")
        for layer_idx in range(num_kv_layers)
    ]
    self.layers = torch.nn.ModuleList(leading_layers)

    # Trailing layers are constructed while shift-parallel mode is
    # switched on in model_runner.
    # NOTE(review): presumably this affects how these layers set up
    # their parallel state -- confirm against set_shift_parallel_mode.
    with model_runner.set_shift_parallel_mode(True):
        swiftkv_layers = [
            LlamaSwiftKVDecoderLayer(config=hf_config,
                                     cache_config=cache_config,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_kv_layers, num_layers)
        ]
        self.layers.extend(swiftkv_layers)

    rms_eps = hf_config.rms_norm_eps
    self.norm_swiftkv = RMSNorm(hf_config.hidden_size, eps=rms_eps)
    self.norm = RMSNorm(hf_config.hidden_size, eps=rms_eps)

    # Mark every parameter owned by the trailing (SwiftKV) layers.
    trailing = self.layers[num_kv_layers:]
    for weight in trailing.parameters():
        weight.shift_parallel_mode = True

    self._init_prefill_runner(vllm_config)
    self._init_decode_runner(vllm_config)

    # Imported here rather than at module scope, as in the original.
    from arctic_inference.py_custom_ops import try_load_torch_library
    self.use_custom_ops = bool(try_load_torch_library())