in janus/janusflow/models/uvit.py [0:0]
def forward(self, x, hs, t_emb):
    """Run the decoder head on ``x`` fused with the last skip connection.

    Args:
        x: rank-4 feature tensor; axis 1 is moved last for ``input_norm`` and
            is also the concatenation axis (presumably (B, C, H, W) -- confirm
            against the encoder).
        hs: mutable stack of skip tensors. Its last entry is popped and
            consumed here, and the stack must be empty afterwards.
        t_emb: conditioning embedding, forwarded to ``mid_block`` only.

    Returns:
        The result of concatenating the normalized ``x`` with the popped skip
        tensor along axis 1. If ``mid_block`` is set, that block is applied
        next, followed by each module in ``out_convs`` in order.

    Raises:
        AssertionError: if ``hs`` still holds entries after the pop.
    """
    # input_norm operates on the trailing axis, so axis 1 is moved to the end
    # for the norm and then put back in its original position.
    normed = self.input_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    # Consume the most recently pushed skip tensor (mutates the caller's hs).
    skip = hs.pop()
    h = torch.cat((normed, skip), dim=1)

    if self.mid_block is not None:
        h = self.mid_block(h, t_emb)

    for conv in self.out_convs:
        h = conv(h)

    # Invariant: exactly one skip tensor was expected to remain for this stage.
    assert len(hs) == 0
    return h