def forward()

in optimum/exporters/neuron/model_wrappers.py [0:0]


    def forward(self, input_ids, attention_mask):
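        """Trace-time forward: run the T5 encoder and pre-build the decoder KV cache.

        The encoder output is duplicated across beams, then each decoder block's
        cross-attention (key, value) pair is projected from it, while the
        self-attention (key, value) pairs are zero-initialized padding. The result
        is a flat list: all self-attention pairs first, then all cross-attention pairs.
        """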
        # Infer shapes of dummy inputs used for tracing
        batch_size = input_ids.shape[0]
        sequence_length = input_ids.shape[1]
        if self.sequence_length is not None:
            assert self.sequence_length == sequence_length, (
                f"Different sequence length for the parallel partition ({self.sequence_length}) and for the dummy inputs ({sequence_length}). Make sure they have the same value."
            )
        if self.batch_size is not None:
            assert self.batch_size == batch_size, (
                f"Different batch size for the parallel partition ({self.batch_size}) and for the dummy inputs ({batch_size}). Make sure they have the same value."
            )

        encoder_output = self.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask, output_attentions=False, output_hidden_states=False
        )

        last_hidden_state = encoder_output["last_hidden_state"]
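        # Duplicate each sample's encoder output num_beams times so every beam can attend to it:
        # result shape is (batch_size * num_beams, sequence_length, d_model).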
        encoder_hidden_states = torch.concat(
            [tensor.unsqueeze(0).repeat(self.num_beams, 1, 1) for tensor in last_hidden_state]
        )

        decoder_blocks = self.model.decoder.block
        present_key_value_states_sa = []
        present_key_value_states_ca = []
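        # For every decoder block, each list receives one (key, value) pair:
        # cross-attention KV projected from the encoder output, self-attention KV as zero padding.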

        for i, block in enumerate(decoder_blocks):
            # Cross attention has to be initialized with the encoder hidden state
            cross_attention: T5LayerCrossAttention = block.layer[1]
            attention = cross_attention.EncDecAttention

            def shape(states):
                """projection"""
                return states.view(
                    self.num_beams * batch_size,
                    -1,
                    self.num_attention_heads_per_partition,
                    attention.key_value_proj_dim,
                ).transpose(1, 2)

            key_states = shape(attention.k(encoder_hidden_states))
            value_states = shape(attention.v(encoder_hidden_states))

            if self.tensor_parallel_size <= 1:
                # cross_attn_kv_state
                present_key_value_states_ca.append(key_states)
                present_key_value_states_ca.append(value_states)

                # Self-attention KV states are zero-initialized and padded to a fixed shape
                # (sequence_length - 1) so the size of the KV cache tensor stays constant.
                # [key states]
                present_key_value_states_sa.append(
                    torch.zeros(
                        (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
                        dtype=torch.float32,
                        device=self.device,
                    )
                )
                # [value states]
                present_key_value_states_sa.append(
                    torch.zeros(
                        (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
                        dtype=torch.float32,
                        device=self.device,
                    )
                )
            else:
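                # Tensor-parallel export: derive the cache entries from the pre-allocated
                # past_key_values_ca / past_key_values_sa buffers (zeroed, then filled),
                # keeping the partition-local head count and shapes.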
                present_key_value_states_ca.append((self.past_key_values_ca[i * 2] * 0) + key_states)
                present_key_value_states_ca.append((self.past_key_values_ca[i * 2 + 1] * 0) + value_states)
                present_key_value_states_sa.append(
                    self.past_key_values_sa[i * 2]
                    * torch.zeros(
                        (
                            self.num_beams * self.batch_size,
                            self.num_attention_heads_per_partition,
                            self.sequence_length - 1,
                            self.config.d_kv,
                        ),
                        dtype=torch.float32,
                        device=self.device,
                    )
                )
                present_key_value_states_sa.append(
                    self.past_key_values_sa[i * 2 + 1]
                    * torch.zeros(
                        (
                            self.num_beams * self.batch_size,
                            self.num_attention_heads_per_partition,
                            self.sequence_length - 1,
                            self.config.d_kv,
                        ),
                        dtype=torch.float32,
                        device=self.device,
                    )
                )

        return present_key_value_states_sa + present_key_value_states_ca
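
For reference, a minimal sketch (not from the repository) of how the flat list returned above can be regrouped into per-layer KV tuples. `outputs`, `model` and `num_layers` are placeholder names for the wrapper's return value, the underlying T5 model and the decoder depth.

    # Hypothetical consumer code: regroup the flat output into per-layer
    # (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value) tuples.
    num_layers = len(model.decoder.block)      # one (key, value) pair per block and per cache
    sa_flat = outputs[: 2 * num_layers]        # self-attention cache (zero padding)
    ca_flat = outputs[2 * num_layers :]        # cross-attention cache (from the encoder output)
    past_key_values = [
        (sa_flat[2 * i], sa_flat[2 * i + 1], ca_flat[2 * i], ca_flat[2 * i + 1])
        for i in range(num_layers)
    ]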