in arctic_inference/vllm/swiftkv/llama_swiftkv.py [0:0]
def load_weights(
    self, weights: Iterable[tuple[str, torch.Tensor]]
) -> set[str]:
    """Copy checkpoint tensors into this module's parameters.

    Each ``(name, tensor)`` pair from ``weights`` is routed to the matching
    entry of ``self.named_parameters()`` and handed to a loader callable:
    the parameter's own ``weight_loader`` attribute, or
    ``default_weight_loader`` on the paths where a fallback is permitted.
    Every loader call runs inside
    ``model_runner.set_shift_parallel_mode(...)``, which receives the
    parameter's ``shift_parallel_mode`` attribute (``None`` when absent).

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The names of the parameters that were handed a tensor.

    Raises:
        KeyError: If a name that is not skipped has no entry in
            ``self.named_parameters()``.
    """
    # (fused parameter fragment, checkpoint shard fragment, shard id)
    fused_shards = (
        (".qkv_proj.", ".q_proj.", "q"),
        (".qkv_proj.", ".k_proj.", "k"),
        (".qkv_proj.", ".v_proj.", "v"),
        (".gate_up_proj.", ".gate_proj.", 0),
        (".gate_up_proj.", ".up_proj.", 1),
        (".kv_proj_swiftkv.", ".k_proj_swiftkv.", "k"),
        (".kv_proj_swiftkv.", ".v_proj_swiftkv.", "v"),
    )
    # Checkpoint entries whose name contains one of these is never loaded.
    ignored_fragments = (
        "rotary_emb.inv_freq",
        # Models trained using ColossalAI may include these tensors in
        # the checkpoint. Skip them.
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    )
    known_params = dict(self.named_parameters())
    populated: set[str] = set()

    def is_extra_bias(candidate: str) -> bool:
        # Skip loading extra bias for GPTQ models.
        return candidate.endswith(".bias") and candidate not in known_params

    def run_loader(loader, target, *loader_args) -> None:
        # Single place for the "enter shift-parallel mode, then load" step
        # shared by all three loading paths below.
        shift_mode = getattr(target, "shift_parallel_mode", None)
        with model_runner.set_shift_parallel_mode(shift_mode):
            loader(target, *loader_args)

    for name, loaded_weight in weights:
        if any(fragment in name for fragment in ignored_fragments):
            continue

        scale_name = None
        if self.quant_config is not None:
            scale_name = self.quant_config.get_cache_scale(name)
        if scale_name:
            # Loading kv cache quantization scales
            target = known_params[scale_name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            # A tensor that is not 0-dim contributes only its index-0 entry.
            if loaded_weight.dim() != 0:
                loaded_weight = loaded_weight[0]
            run_loader(loader, target, loaded_weight)
            populated.add(scale_name)
            continue

        if "scale" in name:
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, known_params)
            if name is None:
                continue

        went_to_fused_param = False
        for fused_fragment, shard_fragment, shard_id in fused_shards:
            if shard_fragment in name:
                name = name.replace(shard_fragment, fused_fragment)
                if is_extra_bias(name):
                    # Keep scanning with the rewritten name; if nothing
                    # else matches, the fallback path below skips it.
                    continue
                target = known_params[name]
                # No default here: a fused parameter must provide its own
                # shard-aware ``weight_loader``.
                run_loader(target.weight_loader, target, loaded_weight,
                           shard_id)
                went_to_fused_param = True
                break

        if not went_to_fused_param:
            if is_extra_bias(name):
                continue
            target = known_params[name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            run_loader(loader, target, loaded_weight)

        populated.add(name)

    return populated