in optimum/neuron/models/inference/backend/modules/attention/gqa.py
def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool:
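    """Rewrite the checkpoint's attention weights to match this module's GQA layout.

    Depending on the configuration, the hook renames the Hugging Face prefixes, splits a
    fused Wqkv tensor into (or re-fuses it from) separate Q/K/V projections, replicates
    K/V heads and pads the projections for the selected GQA sharding strategy, then writes
    the resulting tensors back into `model_state_dict`. Returns True once the entries for
    this prefix have been rewritten.
    """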
    # Derive this module's prefix and the corresponding Hugging Face checkpoint prefix
    # from the incoming prefix (dropping its last one and two components, respectively).
    prefix_parts = prefix.split(".")
    prefix = ".".join(prefix_parts[:-1])
    hf_prefix = ".".join(prefix_parts[:-2])
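    # Fused checkpoints store Q, K and V as a single `Wqkv` tensor: rename it, then split
    # it back into per-projection tensors along dim 0 using the *source* head counts.
    # Illustrative sizes (hypothetical config, not from the source): with 32 attention
    # heads, 8 KV heads and head_dim=128, the split would be [4096, 1024, 1024].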
    if self.fused_qkv:
        self.replace_prefixes(
            old_prefix=f"{hf_prefix}.Wqkv",
            new_prefix=f"{prefix}.Wqkv",
            model_state_dict=model_state_dict,
        )
        qkv_weight = self.get_weight(
            prefix=prefix, layer=self.Wqkv, layer_name="Wqkv", model_state_dict=model_state_dict
        )
        q_proj_weight, k_proj_weight, v_proj_weight = qkv_weight.split(
            [
                self._src_num_attention_heads * self.head_dim,
                self._src_num_key_value_heads * self.head_dim,
                self._src_num_key_value_heads * self.head_dim,
            ],
            dim=0,
        )
        qkv_bias = self.get_bias(
            prefix=prefix, layer=self.Wqkv, layer_name="Wqkv", model_state_dict=model_state_dict
        )
        if qkv_bias is not None:
            q_proj_bias, k_proj_bias, v_proj_bias = qkv_bias.split(
                [
                    self._src_num_attention_heads * self.head_dim,
                    self._src_num_key_value_heads * self.head_dim,
                    self._src_num_key_value_heads * self.head_dim,
                ],
                dim=0,
            )
        else:
            q_proj_bias, k_proj_bias, v_proj_bias = None, None, None
    else:
        self.replace_prefixes(
            old_prefix=f"{hf_prefix}.q_proj",
            new_prefix=f"{prefix}.q_proj",
            model_state_dict=model_state_dict,
        )
        self.replace_prefixes(
            old_prefix=f"{hf_prefix}.k_proj",
            new_prefix=f"{prefix}.k_proj",
            model_state_dict=model_state_dict,
        )
        self.replace_prefixes(
            old_prefix=f"{hf_prefix}.v_proj",
            new_prefix=f"{prefix}.v_proj",
            model_state_dict=model_state_dict,
        )
        q_proj_weight = self.get_weight(
            prefix=prefix,
            layer=self.q_proj,
            layer_name="q_proj",
            model_state_dict=model_state_dict,
        )
        k_proj_weight = self.get_weight(
            prefix=prefix,
            layer=self.k_proj,
            layer_name="k_proj",
            model_state_dict=model_state_dict,
        )
        v_proj_weight = self.get_weight(
            prefix=prefix,
            layer=self.v_proj,
            layer_name="v_proj",
            model_state_dict=model_state_dict,
        )
        q_proj_bias = self.get_bias(
            prefix=prefix,
            layer=self.q_proj,
            layer_name="q_proj",
            model_state_dict=model_state_dict,
        )
        k_proj_bias = self.get_bias(
            prefix=prefix,
            layer=self.k_proj,
            layer_name="k_proj",
            model_state_dict=model_state_dict,
        )
        v_proj_bias = self.get_bias(
            prefix=prefix,
            layer=self.v_proj,
            layer_name="v_proj",
            model_state_dict=model_state_dict,
        )
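    # When the target KV head count differs from the checkpoint's, replicate each source
    # K/V head: enough times to give every tensor-parallel rank its own copy for
    # REPLICATE_TO_TP_DEGREE, or once per query head in its group for CONVERT_TO_MHA.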
    if self.num_key_value_heads != self._src_num_key_value_heads:
        if self.sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE:
            repeats = self.tp_degree // self._src_num_key_value_heads
        elif self.sharding_strategy == GQA.CONVERT_TO_MHA:
            repeats = self._src_num_attention_heads // self._src_num_key_value_heads
        k_proj_weight = replicate_kv(
            k_proj_weight,
            source_heads=self._src_num_key_value_heads,
            repeats=repeats,
            head_dim=0,
        )
        k_proj_bias = replicate_kv(
            k_proj_bias, source_heads=self._src_num_key_value_heads, repeats=repeats, head_dim=0
        )
        v_proj_weight = replicate_kv(
            v_proj_weight,
            source_heads=self._src_num_key_value_heads,
            repeats=repeats,
            head_dim=0,
        )
        v_proj_bias = replicate_kv(
            v_proj_bias, source_heads=self._src_num_key_value_heads, repeats=repeats, head_dim=0
        )
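    # REPLICATE_TO_TP_DEGREE: pad the query projection group-by-group (interleaved) from
    # the source to the target head count; `source_group_size` is the number of query
    # heads attached to each source KV head.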
    if self.sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE:
        q_proj_weight = maybe_pad_interleaved(
            q_proj_weight,
            pad_dim=0,
            source_heads=self._src_num_attention_heads,
            target_heads=self.num_attention_heads,
            source_group_size=self._src_num_attention_heads // self._src_num_key_value_heads,
        )
        q_proj_bias = maybe_pad_interleaved(
            q_proj_bias,
            pad_dim=0,
            source_heads=self._src_num_attention_heads,
            target_heads=self.num_attention_heads,
            source_group_size=self._src_num_attention_heads // self._src_num_key_value_heads,
        )
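    # CONVERT_TO_MHA: pad Q, K and V at the tail so each projection reaches the target
    # head count.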
    if self.sharding_strategy == GQA.CONVERT_TO_MHA:
        q_proj_weight = maybe_pad_tail(
            q_proj_weight,
            source_heads=self._src_num_attention_heads,
            target_heads=self.num_attention_heads,
            pad_dim=0,
        )
        q_proj_bias = maybe_pad_tail(
            q_proj_bias,
            source_heads=self._src_num_attention_heads,
            target_heads=self.num_attention_heads,
            pad_dim=0,
        )
        k_proj_weight = maybe_pad_tail(
            k_proj_weight,
            source_heads=self._src_num_key_value_heads,
            target_heads=self.num_key_value_heads,
            pad_dim=0,
        )
        k_proj_bias = maybe_pad_tail(
            k_proj_bias,
            source_heads=self._src_num_key_value_heads,
            target_heads=self.num_key_value_heads,
            pad_dim=0,
        )
        v_proj_weight = maybe_pad_tail(
            v_proj_weight,
            source_heads=self._src_num_key_value_heads,
            target_heads=self.num_key_value_heads,
            pad_dim=0,
        )
        v_proj_bias = maybe_pad_tail(
            v_proj_bias,
            source_heads=self._src_num_key_value_heads,
            target_heads=self.num_key_value_heads,
            pad_dim=0,
        )
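    # Write the adjusted tensors back into the state dict, re-fusing Q/K/V into a single
    # Wqkv tensor (and bias) when the module uses a fused projection.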
    if self.fused_qkv:
        qkv_weight = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=0)
        self.set_weight(
            tensor=qkv_weight,
            prefix=prefix,
            layer=self.Wqkv,
            layer_name="Wqkv",
            model_state_dict=model_state_dict,
        )
        if self.bias:
            qkv_bias = torch.cat([q_proj_bias, k_proj_bias, v_proj_bias], dim=0)
            self.set_bias(
                tensor=qkv_bias,
                prefix=prefix,
                layer=self.Wqkv,
                layer_name="Wqkv",
                model_state_dict=model_state_dict,
            )
    else:
        self.set_weight(
            tensor=q_proj_weight,
            prefix=prefix,
            layer=self.q_proj,
            layer_name="q_proj",
            model_state_dict=model_state_dict,
        )
        self.set_weight(
            tensor=k_proj_weight,
            prefix=prefix,
            layer=self.k_proj,
            layer_name="k_proj",
            model_state_dict=model_state_dict,
        )
        self.set_weight(
            tensor=v_proj_weight,
            prefix=prefix,
            layer=self.v_proj,
            layer_name="v_proj",
            model_state_dict=model_state_dict,
        )
        if self.bias:
            self.set_bias(
                tensor=q_proj_bias,
                prefix=prefix,
                layer=self.q_proj,
                layer_name="q_proj",
                model_state_dict=model_state_dict,
            )
            self.set_bias(
                tensor=k_proj_bias,
                prefix=prefix,
                layer=self.k_proj,
                layer_name="k_proj",
                model_state_dict=model_state_dict,
            )
            self.set_bias(
                tensor=v_proj_bias,
                prefix=prefix,
                layer=self.v_proj,
                layer_name="v_proj",
                model_state_dict=model_state_dict,
            )
    return True