in janus/models/vq_model.py [0:0]
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Upsample ``x`` by a factor of 2 with nearest-neighbour interpolation.

    Inputs that are not ``torch.float32`` are interpolated in float32 and
    then cast back to their original dtype, so the result always has the
    same dtype as the input. If ``self.with_conv`` is truthy, the upsampled
    tensor is additionally passed through ``self.conv``.

    Args:
        x: Tensor to upsample. Presumably laid out as (N, C, H, W) --
            TODO confirm against the caller; every dimension after the
            first two is scaled by 2.

    Returns:
        The upsampled (and optionally convolved) tensor.
    """
    input_dtype = x.dtype
    if input_dtype != torch.float32:
        # NOTE(review): the float32 round trip is presumably a workaround
        # for reduced-precision dtypes lacking a nearest-neighbour kernel
        # on some backends -- confirm. Cast back to the *input* dtype, not
        # a hard-coded bfloat16, so float16 (etc.) callers get back what
        # they passed in and stay compatible with the following conv.
        x = F.interpolate(
            x.to(torch.float32), scale_factor=2.0, mode="nearest"
        ).to(input_dtype)
    else:
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

    if self.with_conv:
        x = self.conv(x)
    return x