def forward()

in janus/janusflow/models/uvit.py [0:0]


    def forward(self, x, hs, t_emb):
        """Normalize ``x``, fuse it with the final skip tensor, and project the result.

        Args:
            x: 4-D tensor (the permutes below require rank 4); presumably laid
                out as (B, C, H, W) — TODO confirm against the caller.
            hs: Mutable sequence of skip tensors. Exactly one element is popped
                from its end here, after which it must be empty.
            t_emb: Conditioning passed only to ``self.mid_block`` (if present).

        Returns:
            The tensor produced by applying every module in ``self.out_convs``
            in order. If ``out_convs`` is empty, it is the concatenated and
            optionally mid-processed tensor.

        Raises:
            IndexError: if ``hs`` is empty (from ``pop``).
            AssertionError: if ``hs`` still holds items after the single pop
                (unless Python runs with ``-O``).
        """
        # ``input_norm`` sees axis 1 moved to the last position, which suits
        # norms that act on the trailing dimension. The original axis order
        # is restored right after.
        channels_last = x.permute(0, 2, 3, 1)
        normed = self.input_norm(channels_last).permute(0, 3, 1, 2)

        # Fuse the most recently pushed skip tensor along axis 1.
        skip = hs.pop()
        h = torch.cat([normed, skip], dim=1)

        mid = self.mid_block
        if mid is not None:
            h = mid(h, t_emb)

        for conv in self.out_convs:
            h = conv(h)

        # Every skip tensor must have been consumed by this point.
        assert not hs
        return h