in muse/modeling_transformer_v2.py [0:0]
    def forward(self, hidden_states, context):
        batch, q_seq_len, _ = hidden_states.shape
        kv_seq_len = context.shape[1]

        # Project hidden_states to queries and the conditioning context to keys/values.
        query = self.query(hidden_states)
        key = self.key(context)
        value = self.value(context)

        # Split the hidden dimension into heads: (B, T, nh, hs) for queries,
        # (B, S, nh, hs) for keys/values, where S is the context length.
        query = query.view(batch, q_seq_len, self.num_heads, self.head_dim)
        key = key.view(batch, kv_seq_len, self.num_heads, self.head_dim)
        value = value.view(batch, kv_seq_len, self.num_heads, self.head_dim)

        if self.use_memory_efficient_attention_xformers:
            # xFormers expects (B, seq_len, num_heads, head_dim) inputs and returns the
            # same layout, so the heads are merged back with a single view afterwards.
            attn_output = xops.memory_efficient_attention(
                query,
                key,
                value,
                op=self.xformers_attention_op,
                p=self.config.attention_dropout if self.training else 0.0,
            )
            attn_output = attn_output.view(batch, q_seq_len, self.hidden_size)
        else:
            attn_output = self.attention(query, key, value)

        attn_output = self.out(attn_output)
        return attn_output
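
The non-xFormers branch delegates to self.attention, which is not shown in this excerpt. A minimal sketch of what that fallback could look like, assuming a plain scaled-dot-product attention over the same (B, seq_len, num_heads, head_dim) layout produced by the views above (the function name attention_fallback and its dropout arguments are illustrative, not part of the source):

import math
import torch.nn.functional as F


def attention_fallback(query, key, value, dropout_p=0.0, training=False):
    # Hypothetical stand-in for self.attention: standard scaled-dot-product attention
    # over tensors shaped (B, seq_len, num_heads, head_dim), matching the views above.
    query, key, value = (t.transpose(1, 2) for t in (query, key, value))  # (B, nh, seq, hs)
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.shape[-1])   # (B, nh, T, S)
    probs = F.softmax(scores, dim=-1)
    probs = F.dropout(probs, p=dropout_p, training=training)
    out = probs @ value                                                    # (B, nh, T, hs)
    # Merge heads back to (B, T, hidden_size) so self.out can be applied as in forward.
    batch, num_heads, q_seq_len, head_dim = out.shape
    return out.transpose(1, 2).reshape(batch, q_seq_len, num_heads * head_dim)

Either branch therefore hands self.out a (B, q_seq_len, hidden_size) tensor, so the output projection is shared regardless of which attention backend is active.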