in dall_e/utils.py [0:0]
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run ``F.conv2d`` in either float16 or float32.

    Half precision is selected only when ``self.use_float16`` is truthy AND
    the weight tensor ``self.w`` is on a CUDA device. In every other case
    the input is brought to float32 and the stored parameters are used
    unchanged.

    Args:
        x: Input passed to ``F.conv2d``; converted to the selected compute
            dtype when it does not already have it.

    Returns:
        ``F.conv2d(x, weight, bias, padding=(self.kw - 1) // 2)`` evaluated
        in the selected dtype.
    """
    # Short-circuit keeps the original order: the device is inspected only
    # when float16 was requested.
    run_in_half = self.use_float16 and 'cuda' in self.w.device.type

    if not run_in_half:
        if x.dtype != torch.float32:
            x = x.float()
        # float32 path: parameters are used exactly as stored (no cast).
        weight, bias = self.w, self.b
    else:
        if x.dtype != torch.float16:
            x = x.half()
        # Converted copies for this call only; self.w / self.b are never
        # reassigned, so the stored parameters keep their dtype.
        weight, bias = self.w.half(), self.b.half()

    # NOTE(review): self.kw is presumably the (square) kernel size; with the
    # default stride of 1 this padding keeps H/W unchanged only for odd kw.
    pad = (self.kw - 1) // 2
    return F.conv2d(x, weight, bias, padding=pad)