in dall_e/decoder.py [0:0]
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Validate the decoder input, then run it through ``self.blocks``.

    Args:
        x: A 4-d float32 tensor whose dimension 1 has size ``self.vocab_size``.
            NOTE(review): presumably laid out as (batch, vocab_size, height,
            width), e.g. one-hot token maps; only the rank and dim 1 are
            actually checked here. TODO confirm against callers.

    Returns:
        Whatever ``self.blocks(x)`` returns.

    Raises:
        ValueError: if ``x`` is not 4-d, if its dimension 1 does not equal
            ``self.vocab_size``, or if its dtype is not ``torch.float32``.
            The checks run in that order, and the first failure is raised.
    """
    dims = x.shape

    # Rank is checked first so that indexing dims[1] below is always valid.
    if len(dims) != 4:
        raise ValueError(f'input shape {dims} is not 4d')

    channel_count = dims[1]
    if channel_count != self.vocab_size:
        raise ValueError(f'input has {channel_count} channels but model built for {self.vocab_size}')

    # Only float32 is accepted; other float dtypes are rejected rather than cast.
    if x.dtype != torch.float32:
        raise ValueError('input must have dtype torch.float32')

    return self.blocks(x)