in muse/modeling_movq.py
def forward(self, hidden_states, zq=None):
    residual = hidden_states
    batch, channel, height, width = hidden_states.shape

    # Spatially-conditioned normalization: when the quantized latent `zq` is given,
    # the norm layer modulates the activations with it; otherwise plain normalization.
    if zq is not None:
        hidden_states = self.norm(hidden_states, zq)
    else:
        hidden_states = self.norm(hidden_states)

    # Flatten the spatial grid so attention runs over height*width tokens:
    # (batch, channel, height, width) -> (batch, height*width, channel).
    hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

    # Scaled dot-product attention scale 1/sqrt(d), with d = channel.
    scale = 1.0 / torch.sqrt(torch.tensor(channel, dtype=hidden_states.dtype, device=hidden_states.device))

    query = self.q(hidden_states)
    key = self.k(hidden_states)
    value = self.v(hidden_states)

    if self.use_memory_efficient_attention_xformers:
        # Memory-efficient attention via xformers; its default scale is
        # 1/sqrt(query.shape[-1]), which matches `scale` above.
        hidden_states = xops.memory_efficient_attention(
            query, key, value, attn_bias=None, op=self.xformers_attention_op
        )
    else:
        # attention_scores = scale * (query @ key^T); with beta=0 the `empty`
        # tensor only supplies the output shape and is never read.
        attention_scores = torch.baddbmm(
            torch.empty(
                query.shape[0],
                query.shape[1],
                key.shape[1],
                dtype=query.dtype,
                device=query.device,
            ),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=scale,
        )
        # Softmax in float32 for numerical stability, then cast back.
        attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
        hidden_states = torch.bmm(attention_probs, value)

    hidden_states = self.proj_out(hidden_states)

    # Restore the spatial layout: (batch, height*width, channel) -> (batch, channel, height, width).
    hidden_states = hidden_states.transpose(-1, -2).view(batch, channel, height, width)

    return hidden_states + residual
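
A minimal usage sketch, assuming this forward belongs to an AttnBlock-style module whose constructor wires up `norm`, `q`, `k`, `v`, `proj_out`, and the xformers flags used above; the `AttnBlock(channels)` constructor signature and the `zq` shape are assumptions, not shown in this excerpt:

import torch

# Hypothetical instantiation; only the forward signature is taken from the code above.
block = AttnBlock(64)
block.eval()

hidden_states = torch.randn(2, 64, 16, 16)   # (batch, channel, height, width)
zq = torch.randn(2, 64, 16, 16)              # quantized latent consumed by the conditional norm

with torch.no_grad():
    out = block(hidden_states)               # plain normalization path
    out_cond = block(hidden_states, zq=zq)   # zq-conditioned normalization path

print(out.shape)       # torch.Size([2, 64, 16, 16]) -- residual block, so output matches input shape
print(out_cond.shape)  # torch.Size([2, 64, 16, 16])

Note on the else-branch: up to the float32 softmax cast, it computes the same result as torch.nn.functional.scaled_dot_product_attention(query, key, value) on newer PyTorch versions, since that API also defaults to a 1/sqrt(d) scale.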