in muse/modeling_transformer.py
def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None):
    if encoder_attention_mask is not None and self.use_memory_efficient_attention_xformers:
        raise ValueError("Memory efficient attention does not yet support encoder attention mask")

    # Self-attention attends over hidden_states; cross-attention attends over encoder_hidden_states.
    context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
    batch, q_seq_len, _ = hidden_states.shape
    kv_seq_len = q_seq_len if encoder_hidden_states is None else encoder_hidden_states.shape[1]

    query = self.query(hidden_states)
    key = self.key(context)
    value = self.value(context)

    # Split the projections into heads.
    query = query.view(batch, q_seq_len, self.num_heads, self.head_dim)  # (B, T, nh, hs)
    key = key.view(batch, kv_seq_len, self.num_heads, self.head_dim)  # (B, T, nh, hs)
    value = value.view(batch, kv_seq_len, self.num_heads, self.head_dim)  # (B, T, nh, hs)

    if self.use_memory_efficient_attention_xformers:
        # xformers expects (B, T, nh, hs) inputs and returns the same layout.
        attn_output = xops.memory_efficient_attention(
            query, key, value, op=self.xformers_attention_op, p=self.attention_dropout if self.training else 0.0
        )
        attn_output = attn_output.view(batch, q_seq_len, self.hidden_size)
    else:
        attention_mask = None
        if encoder_attention_mask is not None:
            # Expand the encoder padding mask into a broadcastable query/key attention mask.
            src_attn_mask = torch.ones(batch, q_seq_len, dtype=torch.long, device=query.device)
            attention_mask = make_attention_mask(src_attn_mask, encoder_attention_mask, dtype=query.dtype)
        attn_output = self.attention(query, key, value, attention_mask)

    attn_output = self.out(attn_output)
    return attn_output
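
The helpers make_attention_mask and self.attention are defined elsewhere in the module and are not shown here. For readers following the tensor shapes, the sketch below is an illustrative stand-in, not the repository's implementation: it assumes make_attention_mask turns (B, T_q) and (B, T_kv) padding masks (1 = attend, 0 = ignore) into an additive (B, 1, T_q, T_kv) bias, and that the attention itself is conventional scaled dot-product attention over (B, T, nh, hs) tensors. The names make_attention_mask_sketch and attention_sketch are hypothetical.

import math
import torch


def make_attention_mask_sketch(query_mask, key_mask, dtype):
    # Assumed behaviour, not the repo's helper: build an additive bias of shape
    # (B, 1, T_q, T_kv) that is 0 where attention is allowed and a large
    # negative value where it is masked out.
    mask = query_mask[:, None, :, None] * key_mask[:, None, None, :]
    return (1.0 - mask.to(dtype)) * torch.finfo(dtype).min


def attention_sketch(query, key, value, attention_mask=None):
    # Assumed behaviour, not the repo's helper.
    # query: (B, T_q, nh, hs); key/value: (B, T_kv, nh, hs).
    q = query.permute(0, 2, 1, 3)                               # (B, nh, T_q, hs)
    k = key.permute(0, 2, 1, 3)                                 # (B, nh, T_kv, hs)
    v = value.permute(0, 2, 1, 3)                               # (B, nh, T_kv, hs)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])   # (B, nh, T_q, T_kv)
    if attention_mask is not None:
        scores = scores + attention_mask                        # broadcasts over heads
    probs = scores.softmax(dim=-1)
    out = (probs @ v).permute(0, 2, 1, 3).contiguous()          # (B, T_q, nh, hs)
    return out.view(out.shape[0], out.shape[1], -1)             # (B, T_q, nh * hs)

Under these assumptions the single (B, 1, T_q, T_kv) bias broadcasts across all heads, and the merged (B, T_q, nh * hs) output matches what the self.out projection consumes in the forward pass above.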