in glide_text2im/clip/encoders.py [0:0]
def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Embed a batch of images into the transformer token space.

    Validates the input, normalizes channels, projects non-overlapping
    patches with a strided convolution, prepends a start-of-sequence row
    (a learned constant, or a timestep embedding when the model is
    timestep-conditioned), adds positional embeddings, and applies layer
    normalization.

    :param x: image batch of shape (batch, 3, image_size, image_size).
    :param t: 1-D tensor of per-example timesteps; must be provided iff
        the model was built with ``n_timestep != 0``.
    :return: token tensor of shape
        (batch, 1 + (image_size // patch_size) ** 2, n_state).
    :raises ValueError: if ``x`` or ``t`` fails any shape check below.
    """
    if len(x.shape) != 4:
        raise ValueError("input should be 4d")
    if x.shape[1] != 3:
        raise ValueError("input should have 3 channels")
    if not (x.shape[2] == self.image_size and x.shape[3] == self.image_size):
        raise ValueError(f"input is not {self.image_size} x {self.image_size}")
    # t must be supplied exactly when the model is timestep-conditioned.
    if (self.n_timestep == 0 and t is not None) or (self.n_timestep != 0 and t is None):
        raise ValueError("t should be provided iff the model is timestep-conditioned")
    if self.n_timestep != 0:
        assert t is not None  # narrowed by the check above; aids the type checker
        if len(t.shape) != 1:
            raise ValueError("t should be 1d")
        if t.shape[0] != x.shape[0]:
            raise ValueError("t should have the same batch size as x")
    # Per-channel normalization, then patchify via a strided convolution.
    x = (x - self.channel_means) / self.channel_stds
    x = F.conv2d(x, self.patch_proj, stride=self.patch_size)
    # (batch, n_state, n_patches) -> (batch, n_patches, n_state)
    x = x.reshape(x.shape[0], self.n_state, (self.image_size // self.patch_size) ** 2).permute(
        0, 2, 1
    )
    # Start-of-sequence token: learned constant, or the timestep embedding.
    sot = (
        self.pred_state[None, None].expand(x.shape[0], -1, -1)
        if self.n_timestep == 0
        else F.embedding(cast(torch.Tensor, t), self.w_t)[:, None]
    )
    x = torch.cat((sot, x), dim=1) + self.w_pos[None]
    return self.ln(x)