Excerpt: the `forward()` method of a cross-attention module, from muse/modeling_transformer_v2.py


    def forward(self, hidden_states, context):
        """Cross-attention: queries come from `hidden_states`, keys/values from `context`.

        Args:
            hidden_states: tensor of shape (batch, q_seq_len, hidden_size).
            context: tensor of shape (batch, kv_seq_len, hidden_size) — the
                sequence being attended over.

        Returns:
            Tensor of shape (batch, q_seq_len, hidden_size), the attended and
            output-projected result.
        """
        batch_size, query_len, _ = hidden_states.shape
        context_len = context.shape[1]

        # Project the inputs, then expose the head axis:
        # (B, T, nh * hs) -> (B, T, nh, hs).
        q = self.query(hidden_states).view(batch_size, query_len, self.num_heads, self.head_dim)
        k = self.key(context).view(batch_size, context_len, self.num_heads, self.head_dim)
        v = self.value(context).view(batch_size, context_len, self.num_heads, self.head_dim)

        if self.use_memory_efficient_attention_xformers:
            # Fused xformers kernel; attention dropout is only active while training.
            attn = xops.memory_efficient_attention(
                q,
                k,
                v,
                op=self.xformers_attention_op,
                p=self.config.attention_dropout if self.training else 0.0,
            )
            # Fold the heads back together: (B, T, nh, hs) -> (B, T, hidden_size).
            attn = attn.view(batch_size, query_len, self.hidden_size)
        else:
            # Fallback attention implementation; expected to return (B, T, hidden_size).
            attn = self.attention(q, k, v)

        return self.out(attn)