in janus/janusflow/models/uvit.py [0:0]
def forward(self, x, cond_embeds):
    """Run the residual conv/MLP block, then modulate it by a conditioning signal.

    Args:
        x: 4-D input tensor. Dim 1 is moved to the last axis for the
            channelwise layers and then moved back. Its size must match the
            width of ``scale``/``shift`` below. (Assumed (B, C, H, W); TODO
            confirm against callers.)
        cond_embeds: Conditioning tensor. It goes through SiLU and
            ``self.cond_embeds_mapper``, and the result is split in half along
            dim 1 into ``scale`` and ``shift``.

    Returns:
        ``(block(x) + x) * (1 + scale) + shift``, with ``scale``/``shift``
        broadcast over the last two dimensions of ``x``.
    """
    residual = x

    hidden = self.depthwise(x)

    # Move dim 1 to the last axis so the norm/linear layers operate on it.
    hidden = hidden.permute(0, 2, 3, 1)
    channel_mlp = (
        self.norm,
        self.channelwise_linear_1,
        self.channelwise_act,
        self.channelwise_norm,
        self.channelwise_linear_2,
        self.channelwise_dropout,
    )
    for layer in channel_mlp:
        hidden = layer(hidden)
    # Restore the original axis order before the residual add.
    hidden = hidden.permute(0, 3, 1, 2)

    hidden = hidden + residual

    modulation = self.cond_embeds_mapper(F.silu(cond_embeds))
    scale, shift = modulation.chunk(2, dim=1)
    gain = (1 + scale)[:, :, None, None]
    bias = shift[:, :, None, None]
    # Fused form of hidden * gain + bias.
    return torch.addcmul(bias, hidden, gain, value=1)