def __attrs_post_init__()

in glide_text2im/clip/encoders.py [0:0]


    def __attrs_post_init__(self) -> None:
        """Validate the configuration and create the embedding's parameters.

        attrs calls this hook after its generated ``__init__`` has assigned the
        fields read here (``image_size``, ``patch_size``, ``n_state``,
        ``n_timestep``, ``device``). ``super().__init__()`` runs first so that
        the base class (presumably ``nn.Module`` — the class header is not
        visible here) is initialised before any ``nn.Parameter`` is assigned.

        Attributes created:
            pred_state: ``(n_state,)`` parameter; only when ``n_timestep == 0``.
            w_t: ``(n_timestep, n_state)`` parameter; only when ``n_timestep != 0``.
            patch_proj: ``(n_state, 3, patch_size, patch_size)`` parameter.
            w_pos: ``(1 + n_patch ** 2, n_state)`` parameter, where
                ``n_patch = image_size // patch_size``.
            channel_means / channel_stds: float32 tensors of shape
                ``(1, C, 1, 1)`` with ``C = len(image_channel_means)`` /
                ``len(image_channel_stds)``.
            ln: ``LayerNorm(n_state, eps=1e-5)``.

        Raises:
            ValueError: if ``image_size`` or ``patch_size`` is not positive, if
                ``image_size`` is not divisible by ``patch_size``, or if
                ``n_timestep`` is negative.
        """
        super().__init__()

        if self.image_size <= 0 or self.patch_size <= 0:
            raise ValueError(
                "image_size and patch_size must be positive, got "
                f"image_size={self.image_size}, patch_size={self.patch_size}"
            )
        if self.image_size % self.patch_size != 0:
            raise ValueError(
                f"image_size ({self.image_size}) must be divisible by "
                f"patch_size ({self.patch_size})"
            )
        if self.n_timestep < 0:
            raise ValueError(f"n_timestep must be non-negative, got {self.n_timestep}")

        # Number of patches along each spatial side.
        n_patch = self.image_size // self.patch_size
        # Shape (n_state, 3, patch_size, patch_size); the 3 is presumably the
        # RGB input channel count.
        patch_proj = torch.empty(
            (self.n_state, 3) + 2 * (self.patch_size,), dtype=torch.float32, device=self.device
        )
        # One row per patch plus one extra row — presumably the position of the
        # prepended pred_state / timestep token (confirm against forward()).
        w_pos = torch.empty(
            (1 + n_patch ** 2, self.n_state), dtype=torch.float32, device=self.device
        )

        # The normal_() draws consume the global RNG. Keep them in this exact order
        # so a given seed still yields the same initial weights.
        with torch.no_grad():
            if self.n_timestep == 0:
                # No timestep conditioning: a single learned state vector.
                pred_state = torch.empty((self.n_state,), dtype=torch.float32, device=self.device)
                pred_state.normal_(std=1 / np.sqrt(self.n_state))
                self.pred_state = nn.Parameter(pred_state)
            else:
                # Timestep conditioning: one learned row per timestep.
                w_t = torch.empty(
                    (self.n_timestep, self.n_state), dtype=torch.float32, device=self.device
                )
                w_t.normal_(std=1 / np.sqrt(self.n_state))
                self.w_t = nn.Parameter(w_t)

            patch_proj.normal_(std=np.sqrt(2 / (self.n_state * self.patch_size ** 2)))
            w_pos.normal_(std=1 / np.sqrt(self.n_state))

        self.patch_proj = nn.Parameter(patch_proj)
        self.w_pos = nn.Parameter(w_pos)

        # Per-channel normalisation constants, reshaped to (1, C, 1, 1) for
        # broadcasting.
        # NOTE(review): these are plain tensor attributes, not registered buffers.
        # Module.to() / .cuda() / .half() will not move or cast them, and they are
        # not in state_dict. They stay float32 on self.device.
        self.channel_means = torch.tensor(
            image_channel_means, dtype=torch.float32, device=self.device
        )[None, :, None, None]
        self.channel_stds = torch.tensor(
            image_channel_stds, dtype=torch.float32, device=self.device
        )[None, :, None, None]
        self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)