def forward()

in muse/modeling_movq.py


    def forward(self, hidden_states, zq=None):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape
        if zq is not None:
            hidden_states = self.norm(hidden_states, zq)
        else:
            hidden_states = self.norm(hidden_states)
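        # Note: when zq is not None, self.norm is presumably MoVQ's spatially
        # conditioned normalization, which modulates the normalized activations
        # with features derived from the quantized latents zq; with zq=None it
        # acts as a plain, unconditioned norm.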

        # Flatten spatial dims: (B, C, H, W) -> (B, H*W, C), so each pixel becomes one attention token.
        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
        # Standard attention scaling factor 1/sqrt(d), with d = channel.
        scale = 1.0 / torch.sqrt(torch.tensor(channel, dtype=hidden_states.dtype, device=hidden_states.device))

        query = self.q(hidden_states)
        key = self.k(hidden_states)
        value = self.v(hidden_states)

        if self.use_memory_efficient_attention_xformers:
            # Memory-efficient attention from xformers: a fused kernel that
            # avoids materializing the full (H*W) x (H*W) attention matrix.
            hidden_states = xops.memory_efficient_attention(
                query, key, value, attn_bias=None, op=self.xformers_attention_op
            )
        else:
            # With beta=0, baddbmm ignores the (uninitialized) input tensor and
            # computes alpha * (query @ key^T), i.e. the scaled attention scores.
            attention_scores = torch.baddbmm(
                torch.empty(
                    query.shape[0],
                    query.shape[1],
                    key.shape[1],
                    dtype=query.dtype,
                    device=query.device,
                ),
                query,
                key.transpose(-1, -2),
                beta=0,
                alpha=scale,
            )
            # Softmax in float32 for numerical stability, then cast back.
            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
            hidden_states = torch.bmm(attention_probs, value)

        hidden_states = self.proj_out(hidden_states)
        # Restore the spatial layout: (B, H*W, C) -> (B, C, H, W).
        hidden_states = hidden_states.transpose(-1, -2).view(batch, channel, height, width)

        return hidden_states + residual
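
For reference, the fallback branch computes the same softmax(QK^T / sqrt(d))V as PyTorch's built-in torch.nn.functional.scaled_dot_product_attention, which also defaults to 1/sqrt(last-dim) scaling. Below is a minimal, self-contained sketch of the block under that substitution; the class name SpatialSelfAttentionSketch and the plain GroupNorm are illustrative assumptions, not the original module, which can additionally condition its norm on zq and dispatch to xformers.

import torch
import torch.nn.functional as F
from torch import nn


class SpatialSelfAttentionSketch(nn.Module):
    """Hypothetical stand-in for the block above: single-head self-attention
    over spatial positions, without the zq-conditioned norm or xformers path."""

    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)  # plain norm; the real block may condition on zq
        self.q = nn.Linear(channels, channels)
        self.k = nn.Linear(channels, channels)
        self.v = nn.Linear(channels, channels)
        self.proj_out = nn.Linear(channels, channels)

    def forward(self, hidden_states):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        hidden_states = self.norm(hidden_states)
        # (B, C, H, W) -> (B, H*W, C): every pixel becomes one attention token.
        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        query = self.q(hidden_states)
        key = self.k(hidden_states)
        value = self.v(hidden_states)
        # Defaults to 1/sqrt(channel) scaling, matching `scale` above.
        hidden_states = F.scaled_dot_product_attention(query, key, value)

        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.transpose(-1, -2).view(batch, channel, height, width)
        return hidden_states + residual


# Usage on a dummy feature map:
attn = SpatialSelfAttentionSketch(channels=64)
out = attn(torch.randn(1, 64, 16, 16))
assert out.shape == (1, 64, 16, 16)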