def forward()

in janus/models/vq_model.py [0:0]


    def forward(self, x):
        """Upsample ``x`` by a factor of 2 (nearest neighbour), then optionally convolve.

        The interpolation itself is always carried out in ``torch.float32``;
        the result is then cast back to the dtype ``x`` arrived with, so the
        output dtype matches the input dtype. Earlier behaviour cast every
        non-float32 input to ``torch.bfloat16`` regardless of its original
        dtype, which changed e.g. float16 inputs to bfloat16 before they
        reached ``self.conv``.

        Args:
            x: Tensor accepted by ``F.interpolate`` in ``"nearest"`` mode.
                Presumably an ``(N, C, H, W)`` feature map -- confirm against
                callers.

        Returns:
            The upsampled tensor in the input's dtype, passed through
            ``self.conv`` when ``self.with_conv`` is truthy.

        Note:
            Non-float32 inputs round-trip through float32, so a float64 input
            keeps its dtype but only float32 precision.
        """
        input_dtype = x.dtype

        # ``Tensor.to`` returns the tensor itself when the dtype already
        # matches, so float32 inputs take no extra copy on either cast.
        x = F.interpolate(x.to(torch.float32), scale_factor=2.0, mode="nearest")
        x = x.to(input_dtype)

        if self.with_conv:
            x = self.conv(x)
        return x