def preshard_hook()

in optimum/neuron/models/inference/backend/modules/attention/gqa.py
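
Pre-sharding hook that adapts the attention weights of a Hugging Face checkpoint to the layout this GQA module expects: it renames the checkpoint's q_proj/k_proj/v_proj (or fused Wqkv) entries to the module's prefix, replicates the key/value heads when the target head count differs from the source, pads the query (and, for MHA conversion, key/value) heads according to the selected sharding strategy, and writes the transformed tensors back into model_state_dict before returning True.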


    def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool:
        # `prefix` addresses this GQA module in the target state dict;
        # `hf_prefix` is the original Hugging Face attention-module prefix.
        prefix_parts = prefix.split(".")
        prefix = ".".join(prefix_parts[:-1])
        hf_prefix = ".".join(prefix_parts[:-2])
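        # Fused checkpoints store q/k/v as a single Wqkv tensor: rename it, then
        # split it into per-projection tensors using the *source* head counts so
        # the rest of the hook can treat both layouts uniformly.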
        if self.fused_qkv:
            self.replace_prefixes(
                old_prefix=f"{hf_prefix}.Wqkv",
                new_prefix=f"{prefix}.Wqkv",
                model_state_dict=model_state_dict,
            )
            qkv_weight = self.get_weight(
                prefix=prefix, layer=self.Wqkv, layer_name="Wqkv", model_state_dict=model_state_dict
            )
            q_proj_weight, k_proj_weight, v_proj_weight = qkv_weight.split(
                [
                    self._src_num_attention_heads * self.head_dim,
                    self._src_num_key_value_heads * self.head_dim,
                    self._src_num_key_value_heads * self.head_dim,
                ],
                dim=0,
            )
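            # Biases are optional in the checkpoint; split them the same way when present.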
            qkv_bias = self.get_bias(
                prefix=prefix, layer=self.Wqkv, layer_name="Wqkv", model_state_dict=model_state_dict
            )
            if qkv_bias is not None:
                q_proj_bias, k_proj_bias, v_proj_bias = qkv_bias.split(
                    [
                        self._src_num_attention_heads * self.head_dim,
                        self._src_num_key_value_heads * self.head_dim,
                        self._src_num_key_value_heads * self.head_dim,
                    ],
                    dim=0,
                )
            else:
                q_proj_bias, k_proj_bias, v_proj_bias = None, None, None
        else:
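            # Unfused checkpoints already store separate q/k/v projections;
            # only the prefixes need renaming.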
            self.replace_prefixes(
                old_prefix=f"{hf_prefix}.q_proj",
                new_prefix=f"{prefix}.q_proj",
                model_state_dict=model_state_dict,
            )
            self.replace_prefixes(
                old_prefix=f"{hf_prefix}.k_proj",
                new_prefix=f"{prefix}.k_proj",
                model_state_dict=model_state_dict,
            )
            self.replace_prefixes(
                old_prefix=f"{hf_prefix}.v_proj",
                new_prefix=f"{prefix}.v_proj",
                model_state_dict=model_state_dict,
            )

            q_proj_weight = self.get_weight(
                prefix=prefix,
                layer=self.q_proj,
                layer_name="q_proj",
                model_state_dict=model_state_dict,
            )
            k_proj_weight = self.get_weight(
                prefix=prefix,
                layer=self.k_proj,
                layer_name="k_proj",
                model_state_dict=model_state_dict,
            )
            v_proj_weight = self.get_weight(
                prefix=prefix,
                layer=self.v_proj,
                layer_name="v_proj",
                model_state_dict=model_state_dict,
            )

            q_proj_bias = self.get_bias(
                prefix=prefix,
                layer=self.q_proj,
                layer_name="q_proj",
                model_state_dict=model_state_dict,
            )
            k_proj_bias = self.get_bias(
                prefix=prefix,
                layer=self.k_proj,
                layer_name="k_proj",
                model_state_dict=model_state_dict,
            )
            v_proj_bias = self.get_bias(
                prefix=prefix,
                layer=self.v_proj,
                layer_name="v_proj",
                model_state_dict=model_state_dict,
            )

        if self.num_key_value_heads != self._src_num_key_value_heads:
            # Replicate each source KV head so the target layout has one KV head
            # per tensor-parallel rank (REPLICATE_TO_TP_DEGREE) or one per query
            # head (CONVERT_TO_MHA). NOTE: `repeats` is only assigned for these
            # two strategies; other strategies are expected to keep the source
            # KV head count, so they never reach this branch.
            if self.sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE:
                repeats = self.tp_degree // self._src_num_key_value_heads
            elif self.sharding_strategy == GQA.CONVERT_TO_MHA:
                repeats = self._src_num_attention_heads // self._src_num_key_value_heads
            k_proj_weight = replicate_kv(
                k_proj_weight,
                source_heads=self._src_num_key_value_heads,
                repeats=repeats,
                head_dim=0,
            )
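            # Biases may be None at this point; replicate_kv is expected to
            # pass None through unchanged.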
            k_proj_bias = replicate_kv(
                k_proj_bias, source_heads=self._src_num_key_value_heads, repeats=repeats, head_dim=0
            )
            v_proj_weight = replicate_kv(
                v_proj_weight,
                source_heads=self._src_num_key_value_heads,
                repeats=repeats,
                head_dim=0,
            )
            v_proj_bias = replicate_kv(
                v_proj_bias, source_heads=self._src_num_key_value_heads, repeats=repeats, head_dim=0
            )

        if self.sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE:
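            # Pad the query heads up to the target count group-by-group (one
            # group of queries per source KV head), keeping each group contiguous.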
            q_proj_weight = maybe_pad_interleaved(
                q_proj_weight,
                pad_dim=0,
                source_heads=self._src_num_attention_heads,
                target_heads=self.num_attention_heads,
                source_group_size=self._src_num_attention_heads // self._src_num_key_value_heads,
            )
            q_proj_bias = maybe_pad_interleaved(
                q_proj_bias,
                pad_dim=0,
                source_heads=self._src_num_attention_heads,
                target_heads=self.num_attention_heads,
                source_group_size=self._src_num_attention_heads // self._src_num_key_value_heads,
            )

        if self.sharding_strategy == GQA.CONVERT_TO_MHA:
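            # For MHA conversion, pad each projection at the tail until it
            # reaches the target head count.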
            q_proj_weight = maybe_pad_tail(
                q_proj_weight,
                source_heads=self._src_num_attention_heads,
                target_heads=self.num_attention_heads,
                pad_dim=0,
            )
            q_proj_bias = maybe_pad_tail(
                q_proj_bias,
                source_heads=self._src_num_attention_heads,
                target_heads=self.num_attention_heads,
                pad_dim=0,
            )
            k_proj_weight = maybe_pad_tail(
                k_proj_weight,
                source_heads=self._src_num_key_value_heads,
                target_heads=self.num_key_value_heads,
                pad_dim=0,
            )
            k_proj_bias = maybe_pad_tail(
                k_proj_bias,
                source_heads=self._src_num_key_value_heads,
                target_heads=self.num_key_value_heads,
                pad_dim=0,
            )
            v_proj_weight = maybe_pad_tail(
                v_proj_weight,
                source_heads=self._src_num_key_value_heads,
                target_heads=self.num_key_value_heads,
                pad_dim=0,
            )
            v_proj_bias = maybe_pad_tail(
                v_proj_bias,
                source_heads=self._src_num_key_value_heads,
                target_heads=self.num_key_value_heads,
                pad_dim=0,
            )

        if self.fused_qkv:
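            # Re-fuse the transformed projections into a single Wqkv tensor.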
            qkv_weight = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=0)
            self.set_weight(
                tensor=qkv_weight,
                prefix=prefix,
                layer=self.Wqkv,
                layer_name="Wqkv",
                model_state_dict=model_state_dict,
            )
            if self.bias:
                qkv_bias = torch.cat([q_proj_bias, k_proj_bias, v_proj_bias], dim=0)
                self.set_bias(
                    tensor=qkv_bias,
                    prefix=prefix,
                    layer=self.Wqkv,
                    layer_name="Wqkv",
                    model_state_dict=model_state_dict,
                )
        else:
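            # Unfused layout: write each transformed projection back under its own key.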
            self.set_weight(
                tensor=q_proj_weight,
                prefix=prefix,
                layer=self.q_proj,
                layer_name="q_proj",
                model_state_dict=model_state_dict,
            )
            self.set_weight(
                tensor=k_proj_weight,
                prefix=prefix,
                layer=self.k_proj,
                layer_name="k_proj",
                model_state_dict=model_state_dict,
            )
            self.set_weight(
                tensor=v_proj_weight,
                prefix=prefix,
                layer=self.v_proj,
                layer_name="v_proj",
                model_state_dict=model_state_dict,
            )

            if self.bias:
                self.set_bias(
                    tensor=q_proj_bias,
                    prefix=prefix,
                    layer=self.q_proj,
                    layer_name="q_proj",
                    model_state_dict=model_state_dict,
                )
                self.set_bias(
                    tensor=k_proj_bias,
                    prefix=prefix,
                    layer=self.k_proj,
                    layer_name="k_proj",
                    model_state_dict=model_state_dict,
                )
                self.set_bias(
                    tensor=v_proj_bias,
                    prefix=prefix,
                    layer=self.v_proj,
                    layer_name="v_proj",
                    model_state_dict=model_state_dict,
                )

        return True
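
For intuition, here is a minimal, self-contained sketch of the head-replication and tail-padding arithmetic on toy tensors. toy_replicate_kv and toy_pad_tail are hypothetical stand-ins for the library's replicate_kv and maybe_pad_tail helpers (their real signatures and behavior live in the library); the numbers mirror the REPLICATE_TO_TP_DEGREE case with 2 source KV heads, head_dim=2, and tp_degree=4.

    import torch

    def toy_replicate_kv(weight: torch.Tensor, source_heads: int, repeats: int) -> torch.Tensor:
        # Hypothetical stand-in for replicate_kv: view the projection as
        # per-head row blocks and repeat each head `repeats` times, keeping
        # the copies of a given head adjacent.
        head_rows = weight.shape[0] // source_heads
        heads = weight.reshape(source_heads, head_rows, -1)
        return heads.repeat_interleave(repeats, dim=0).reshape(-1, weight.shape[1])

    def toy_pad_tail(weight: torch.Tensor, source_heads: int, target_heads: int) -> torch.Tensor:
        # Hypothetical stand-in for maybe_pad_tail: append zero-filled heads
        # at the end of the projection until target_heads is reached.
        head_rows = weight.shape[0] // source_heads
        pad = weight.new_zeros((target_heads - source_heads) * head_rows, weight.shape[1])
        return torch.cat([weight, pad], dim=0)

    # 2 source KV heads, head_dim=2, hidden_size=4 -> k_proj weight is (4, 4).
    k_proj_weight = torch.arange(16.0).reshape(4, 4)
    repeats = 4 // 2  # tp_degree // _src_num_key_value_heads for tp_degree=4
    replicated = toy_replicate_kv(k_proj_weight, source_heads=2, repeats=repeats)
    print(replicated.shape)  # torch.Size([8, 4]): 4 KV heads, one per rank

    padded = toy_pad_tail(replicated, source_heads=4, target_heads=6)
    print(padded.shape)  # torch.Size([12, 4]): two zero heads appended at the tail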