in janus/janusflow/models/uvit.py [0:0]
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """Run the optional norm, the optional zero-pad, then ``self.conv``.

    Args:
        hidden_states: Input tensor whose axis 1 must have length
            ``self.channels``. The norm branch permutes four axes, so a
            4-D input is required whenever ``self.norm`` is set
            (presumably ``(batch, channels, height, width)`` -- confirm
            against the caller).
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        Whatever ``self.conv`` returns for the prepared tensor.

    Raises:
        AssertionError: If axis 1 of the tensor does not equal
            ``self.channels``, checked on entry and again right before
            ``self.conv`` is called.
    """
    features = hidden_states
    assert features.shape[1] == self.channels

    if self.norm is not None:
        # The norm layer is fed the tensor with axis 1 moved to the last
        # position; the inverse permutation restores the original layout.
        channels_last = features.permute(0, 2, 3, 1)
        normalized = self.norm(channels_last)
        features = normalized.permute(0, 3, 1, 2)

    if self.use_conv:
        # ``self.padding`` is only consulted when ``use_conv`` is truthy.
        if self.padding == 0:
            # F.pad reads the tuple from the last axis backwards, so this
            # appends one zero at the end of each of the last two axes.
            end_only_pad = (0, 1, 0, 1)
            features = F.pad(features, end_only_pad, mode="constant", value=0)

    assert features.shape[1] == self.channels
    return self.conv(features)