in muse/modeling_transformer.py [0:0]
def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None, cond_embeds=None):
    # Self-attention block with a pre-norm residual connection.
    residual = hidden_states

    hidden_states = self.attn_layer_norm(hidden_states)
    if cond_embeds is not None:
        # Adaptive LayerNorm: modulate the normalized states with the conditioning embeddings.
        hidden_states = self.self_attn_adaLN_modulation(hidden_states, cond_embeds)
    attention_output = self.attention(hidden_states)
    if self.use_normformer:
        # NormFormer-style post-attention LayerNorm applied before the residual add.
        attention_output = self.post_attn_layer_norm(attention_output)
    hidden_states = residual + attention_output

    # Cross-attention block, only when encoder states are provided.
    if encoder_hidden_states is not None:
        residual = hidden_states

        # TODO: should norm be applied to encoder_hidden_states as well?
        hidden_states = self.crossattn_layer_norm(hidden_states)
        if cond_embeds is not None:
            hidden_states = self.cross_attn_adaLN_modulation(hidden_states, cond_embeds)
        attention_output = self.crossattention(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        if self.use_normformer:
            attention_output = self.post_crossattn_layer_norm(attention_output)
        hidden_states = residual + attention_output

    # Feed-forward block with its own residual connection.
    residual = hidden_states
    hidden_states = self.ffn(hidden_states, cond_embeds=cond_embeds)
    hidden_states = residual + hidden_states
    return hidden_states
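

# The adaLN modulation modules referenced above (self_attn_adaLN_modulation,
# cross_attn_adaLN_modulation) are not shown in this excerpt. Below is a minimal
# sketch of what such a module might look like, assuming the common scale/shift
# formulation of adaptive LayerNorm; the class name AdaLNModulation, its
# constructor arguments, and its internals are assumptions for illustration,
# not the repo's actual implementation.
import torch.nn as nn


class AdaLNModulation(nn.Module):
    def __init__(self, cond_embed_dim, hidden_size):
        super().__init__()
        # Project the conditioning embedding to a per-channel scale and shift.
        self.mapper = nn.Linear(cond_embed_dim, hidden_size * 2)

    def forward(self, hidden_states, cond_embeds):
        scale, shift = self.mapper(cond_embeds).chunk(2, dim=-1)
        # Broadcast over the sequence dimension: (batch, 1, hidden_size).
        scale, shift = scale.unsqueeze(1), shift.unsqueeze(1)
        return hidden_states * (1 + scale) + shift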