parler_tts/modeling_parler_tts.py [515:536]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        query_states = self._shape_query(query_states, tgt_len, bsz)
        if self.rope_embeddings:
            query_states = apply_rotary_pos_emb(query_states, cos, sin)

        if past_key_value is not None:
            # falsy until this layer's cross-attention key/value states are first cached
            is_updated = past_key_value.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated token, all cross-attention key/value states
                # can be reused from the cache
                past_key_value.is_updated[self.layer_idx] = True
                past_key_value = past_key_value.cross_attention_cache
            else:
                past_key_value = past_key_value.self_attention_cache

        # cross-attention reads the encoder output (key_value_states); self-attention reads hidden_states
        current_states = key_value_states if key_value_states is not None else hidden_states
        if is_cross_attention and past_key_value and is_updated:
            # reuse the cached cross-attention key/value states
            key_states = past_key_value.key_cache[self.layer_idx]
            value_states = past_key_value.value_cache[self.layer_idx]
        else:
            key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz)
            value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
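
Both excerpts gate the cross-attention projections on EncoderDecoderCache.is_updated: because cross-attention key/value states depend only on the fixed encoder output, they are projected once on the first decoding step, flagged per layer, and served from cross_attention_cache on every later step. Below is a minimal standalone sketch of that pattern, assuming the transformers DynamicCache/EncoderDecoderCache API the snippet relies on; cross_attention_kv, k_proj/v_proj, and the tensor shapes are illustrative placeholders, not names from the file.

import torch
from transformers import DynamicCache, EncoderDecoderCache

bsz, num_heads, head_dim, layer_idx = 1, 2, 4, 0
k_proj = torch.nn.Linear(num_heads * head_dim, num_heads * head_dim)
v_proj = torch.nn.Linear(num_heads * head_dim, num_heads * head_dim)
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

def cross_attention_kv(encoder_states):
    # first call: project the encoder output and cache it; later calls: reuse it
    if cache.is_updated.get(layer_idx):
        cross = cache.cross_attention_cache
        return cross.key_cache[layer_idx], cross.value_cache[layer_idx]
    k = k_proj(encoder_states).view(bsz, -1, num_heads, head_dim).transpose(1, 2)
    v = v_proj(encoder_states).view(bsz, -1, num_heads, head_dim).transpose(1, 2)
    cache.cross_attention_cache.update(k, v, layer_idx)
    cache.is_updated[layer_idx] = True
    return k, v

encoder_states = torch.randn(bsz, 5, num_heads * head_dim)  # 5 encoder positions
k1, v1 = cross_attention_kv(encoder_states)  # projected and cached
k2, v2 = cross_attention_kv(encoder_states)  # served from the cache
assert torch.equal(k1, k2)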



parler_tts/modeling_parler_tts.py [856:878]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        query_states = self._shape_query(query_states, tgt_len, bsz)

        if self.rope_embeddings:
            query_states = apply_rotary_pos_emb(query_states, cos, sin)

        if past_key_value is not None:
            # falsy until this layer's cross-attention key/value states are first cached
            is_updated = past_key_value.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated token, all cross-attention key/value states
                # can be reused from the cache
                past_key_value.is_updated[self.layer_idx] = True
                past_key_value = past_key_value.cross_attention_cache
            else:
                past_key_value = past_key_value.self_attention_cache

        # cross-attention reads the encoder output (key_value_states); self-attention reads hidden_states
        current_states = key_value_states if key_value_states is not None else hidden_states
        if is_cross_attention and past_key_value and is_updated:
            # reuse the cached cross-attention key/value states
            key_states = past_key_value.key_cache[self.layer_idx]
            value_states = past_key_value.value_cache[self.layer_idx]
        else:
            key_states = self._shape_key_value(self.k_proj(current_states), -1, bsz)
            value_states = self._shape_key_value(self.v_proj(current_states), -1, bsz)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
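
When the layer is doing self-attention, the code swaps in self_attention_cache instead; on that path the projections run on every step and each new token's key/value states are appended to the cache rather than reused wholesale (the per-step update call itself sits outside the quoted span). A short sketch of that append-per-step behaviour, again assuming transformers' DynamicCache.update semantics, with illustrative shapes:

import torch
from transformers import DynamicCache

self_attention_cache = DynamicCache()
layer_idx = 0

# each decode step contributes key/value states for one new token; update()
# concatenates along the sequence axis and returns the full cached states so far
for step in range(3):
    k_new = torch.randn(1, 2, 1, 4)  # (bsz, num_heads, seq_len=1, head_dim)
    v_new = torch.randn(1, 2, 1, 4)
    key_states, value_states = self_attention_cache.update(k_new, v_new, layer_idx)

print(key_states.shape)  # torch.Size([1, 2, 3, 4]) after three steps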

