def forward()

in janus/janusflow/models/uvit.py [0:0]


    def forward(self, x, cond_embeds):
        """Run the residual channel-mixing block, then modulate it by ``cond_embeds``.

        1. The input goes through ``self.depthwise``.
        2. The result is permuted so that dim 1 moves to the last position.
        3. It runs through the per-channel stack:
           norm -> linear_1 -> act -> channelwise_norm -> linear_2 -> dropout.
        4. It is permuted back and added to the unmodified input.
        5. The sum is modulated by a scale/shift pair predicted from
           ``SiLU(cond_embeds)``.

        Args:
            x: 4-D feature map. Presumably laid out as
                (batch, channels, height, width), judging by the permutes.
                TODO confirm against callers.
            cond_embeds: Conditioning embeddings. Assumed to be
                (batch, embed_dim), so that the mapper output splits into scale
                and shift along dim 1. Verify against callers.

        Returns:
            A tensor equal to ``h * (1 + scale) + shift``, where ``h`` is the
            residual sum. ``scale`` and ``shift`` are broadcast over the last
            two axes.
        """
        residual = x

        # Move dim 1 to the end so the norm/linear layers operate on it.
        hidden = self.depthwise(x).permute(0, 2, 3, 1)
        channel_mlp = (
            self.norm,
            self.channelwise_linear_1,
            self.channelwise_act,
            self.channelwise_norm,
            self.channelwise_linear_2,
            self.channelwise_dropout,
        )
        for layer in channel_mlp:
            hidden = layer(hidden)
        hidden = hidden.permute(0, 3, 1, 2) + residual

        modulation = self.cond_embeds_mapper(F.silu(cond_embeds))
        scale, shift = modulation.chunk(2, dim=1)
        # Fused equivalent of hidden * (1 + scale) + shift. The per-channel
        # scale/shift vectors get two trailing singleton axes for broadcasting.
        return torch.addcmul(
            shift[:, :, None, None], hidden, (1 + scale)[:, :, None, None]
        )