def forward()

in glide_text2im/clip/encoders.py


    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Tensor:
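        # x: a batch of RGB images with shape (batch, 3, image_size, image_size).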
        if len(x.shape) != 4:
            raise ValueError("input should be 4d")
        if x.shape[1] != 3:
            raise ValueError("input should have 3 channels")
        if not (x.shape[2] == self.image_size and x.shape[3] == self.image_size):
            raise ValueError(f"input is not {self.image_size} x {self.image_size}")

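        # A timestep tensor is required exactly when the encoder is timestep-conditioned
        # (n_timestep > 0): one integer timestep per example in the batch.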
        if (self.n_timestep == 0 and t is not None) or (self.n_timestep != 0 and t is None):
            raise ValueError("t should be provided if and only if n_timestep != 0")
        if self.n_timestep != 0:
            assert t is not None
            if len(t.shape) != 1:
                raise ValueError("t should be 1d")
            if t.shape[0] != x.shape[0]:
                raise ValueError("t should have the same batch size as x")

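        # Normalize per channel, then project non-overlapping patches with a strided
        # convolution and flatten the spatial grid into a (batch, n_patches, n_state)
        # sequence of patch tokens, where n_patches = (image_size // patch_size) ** 2.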
        x = (x - self.channel_means) / self.channel_stds
        x = F.conv2d(x, self.patch_proj, stride=self.patch_size)
        x = x.reshape(x.shape[0], self.n_state, (self.image_size // self.patch_size) ** 2).permute(
            0, 2, 1
        )

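        # Prepend a single start-of-sequence token: the learned pred_state vector in the
        # unconditional case, or a learned per-timestep embedding (a row of w_t) otherwise.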
        sot = (
            self.pred_state[None, None].expand(x.shape[0], -1, -1)
            if self.n_timestep == 0
            else F.embedding(cast(torch.Tensor, t), self.w_t)[:, None]
        )
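        # Add learned positional embeddings to the full sequence and apply layer norm.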
        x = torch.cat((sot, x), dim=1) + self.w_pos[None]
        return self.ln(x)
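
To make the shape bookkeeping concrete, the following is a self-contained sketch of the unconditional path (n_timestep == 0) using plain tensors. The sizes (image_size=64, patch_size=8, n_state=512) and the random parameters are illustrative placeholders, not the module's actual configuration; only the tensor shapes are meant to mirror the code above.

    import torch
    import torch.nn.functional as F

    # Illustrative sizes and stand-in parameters (not the real module's values).
    image_size, patch_size, n_state = 64, 8, 512
    n_patches = (image_size // patch_size) ** 2             # 64 patches

    channel_means = torch.zeros(1, 3, 1, 1)                 # stand-in for self.channel_means
    channel_stds = torch.ones(1, 3, 1, 1)                   # stand-in for self.channel_stds
    patch_proj = torch.randn(n_state, 3, patch_size, patch_size)
    pred_state = torch.randn(n_state)
    w_pos = torch.randn(n_patches + 1, n_state)

    x = torch.randn(2, 3, image_size, image_size)           # (batch, RGB, H, W)
    x = (x - channel_means) / channel_stds                  # per-channel normalization
    x = F.conv2d(x, patch_proj, stride=patch_size)          # (2, n_state, 8, 8)
    x = x.reshape(x.shape[0], n_state, n_patches).permute(0, 2, 1)  # (2, 64, n_state)

    sot = pred_state[None, None].expand(x.shape[0], -1, -1)  # (2, 1, n_state) start token
    x = torch.cat((sot, x), dim=1) + w_pos[None]             # (2, 65, n_state)
    x = F.layer_norm(x, (n_state,))                          # stand-in for self.ln(x)
    print(x.shape)                                           # torch.Size([2, 65, 512])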