def forward()

in muse/modeling_transformer.py [0:0]
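
Forward pass of a single transformer layer: a pre-norm self-attention block, an optional cross-attention block over encoder_hidden_states, and a feed-forward block, each wrapped in a residual connection. When cond_embeds is provided, adaLN modulation is applied after each pre-norm; when use_normformer is set, an extra LayerNorm is applied to each attention output before the residual add.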


    def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None, cond_embeds=None):
        # Self-attention block: pre-LayerNorm, optional adaLN modulation from
        # cond_embeds, and a residual connection around the attention output.
        residual = hidden_states

        hidden_states = self.attn_layer_norm(hidden_states)
        if cond_embeds is not None:
            hidden_states = self.self_attn_adaLN_modulation(hidden_states, cond_embeds)
        attention_output = self.attention(hidden_states)
        if self.use_normformer:
            # NormFormer-style extra LayerNorm on the attention output before the residual add.
            attention_output = self.post_attn_layer_norm(attention_output)
        hidden_states = residual + attention_output

        # Cross-attention block: only runs when encoder_hidden_states are provided.
        if encoder_hidden_states is not None:
            residual = hidden_states
            # TODO: should norm be applied to encoder_hidden_states as well?
            hidden_states = self.crossattn_layer_norm(hidden_states)
            if cond_embeds is not None:
                hidden_states = self.cross_attn_adaLN_modulation(hidden_states, cond_embeds)
            attention_output = self.crossattention(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
            )
            if self.use_normformer:
                attention_output = self.post_crossattn_layer_norm(attention_output)
            hidden_states = residual + attention_output

        # Feed-forward block with residual connection; cond_embeds is passed
        # through so the FFN can condition on it as well.
        residual = hidden_states
        hidden_states = self.ffn(hidden_states, cond_embeds=cond_embeds)
        hidden_states = residual + hidden_states
        return hidden_states
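
A minimal calling sketch (not from the source): layer stands for an already-constructed transformer layer from muse/modeling_transformer.py, and all tensor shapes below are illustrative placeholders rather than values taken from the repository.

    import torch

    # `layer` is assumed to be an already-constructed transformer layer instance;
    # its constructor arguments are not shown in this section.
    # All shapes are illustrative, not taken from the source.
    batch, seq_len, hidden_size = 2, 256, 1024
    text_len, encoder_hidden_size = 77, 768

    hidden_states = torch.randn(batch, seq_len, hidden_size)
    encoder_hidden_states = torch.randn(batch, text_len, encoder_hidden_size)
    encoder_attention_mask = torch.ones(batch, text_len, dtype=torch.bool)
    cond_embeds = torch.randn(batch, hidden_size)  # pooled conditioning vector

    out = layer(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        cond_embeds=cond_embeds,
    )
    assert out.shape == hidden_states.shape  # residual adds preserve the shape

Because every sub-block is a residual update, the output has the same shape as the input hidden_states; omitting encoder_hidden_states simply skips the cross-attention block.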