in muse/modeling_movq.py [0:0]
def forward(self, hidden_states, zq=None):
    """Run two norm -> SiLU -> conv stages and add the skip input back.

    Args:
        hidden_states: Input to the block; it is also kept as the skip
            (residual) input that is added to the result.
        zq: Optional extra input. When it is not ``None`` it is passed as
            the second positional argument to both ``self.norm1`` and
            ``self.norm2``; otherwise each norm is called with the
            activations only.

    Returns:
        The output of ``self.conv2`` plus the skip input. When
        ``self.in_channels`` differs from ``self.out_channels`` the skip
        input first goes through ``self.conv_shortcut`` (if
        ``self.use_conv_shortcut`` is truthy) or ``self.nin_shortcut``.
    """
    # Extra positional arguments for the norm layers: empty unless zq is given.
    cond_args = () if zq is None else (zq,)

    # Stage 1: norm -> SiLU -> conv.
    out = self.conv1(F.silu(self.norm1(hidden_states, *cond_args)))
    # Stage 2: norm -> SiLU -> dropout -> conv.
    out = F.silu(self.norm2(out, *cond_args))
    out = self.conv2(self.dropout(out))

    # Matching channel counts: the untouched input is the skip connection.
    if self.in_channels == self.out_channels:
        return out + hidden_states

    # Otherwise project the input so it can be summed with the output.
    project = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
    return out + project(hidden_states)