in glide_text2im/clip/encoders.py [0:0]
def __attrs_post_init__(self) -> None:
    """Initialize the base module and allocate this encoder's learnable state.

    Runs as the attrs post-init hook, after the attrs-generated ``__init__``
    has populated ``image_size``, ``patch_size``, ``n_state``, ``n_timestep``
    and ``device``.

    Creates:
      * ``patch_proj`` -- parameter of shape
        ``(n_state, 3, patch_size, patch_size)`` (presumably a conv-style
        projection of 3-channel image patches; confirm against ``forward``).
      * ``w_pos`` -- parameter of shape ``(1 + n_patch ** 2, n_state)`` where
        ``n_patch = image_size // patch_size``; the extra row is presumably
        for a leading prediction/timestep token -- TODO confirm.
      * ``pred_state`` of shape ``(n_state,)`` when ``n_timestep == 0``,
        otherwise ``w_t`` of shape ``(n_timestep, n_state)``.
      * ``channel_means`` / ``channel_stds`` of shape ``(1, C, 1, 1)``.
      * ``ln`` -- a ``LayerNorm`` over ``n_state`` features with eps 1e-5.

    Raises:
        ValueError: if ``image_size`` is not divisible by ``patch_size``.
    """
    # Base-class init must run before any nn.Parameter attribute is assigned
    # (presumably the base is nn.Module, which tracks parameters by setattr).
    super().__init__()

    if self.image_size % self.patch_size != 0:
        raise ValueError(
            f"image_size ({self.image_size}) must be divisible by "
            f"patch_size ({self.patch_size})"
        )

    n_patch = self.image_size // self.patch_size
    # Shared init scale for the pred_state / w_t / w_pos parameters.
    init_std = 1 / np.sqrt(self.n_state)

    patch_proj = torch.empty(
        (self.n_state, 3) + 2 * (self.patch_size,), dtype=torch.float32, device=self.device
    )
    w_pos = torch.empty(
        (1 + n_patch ** 2, self.n_state), dtype=torch.float32, device=self.device
    )

    # In-place random init of freshly allocated tensors; no autograd tracking
    # is needed. NOTE: the order of the normal_() draws below is kept as-is so
    # that seeded initialization stays reproducible.
    with torch.no_grad():
        if self.n_timestep == 0:
            pred_state = torch.empty((self.n_state,), dtype=torch.float32, device=self.device)
            pred_state.normal_(std=init_std)
            self.pred_state = nn.Parameter(pred_state)
        else:
            w_t = torch.empty(
                (self.n_timestep, self.n_state), dtype=torch.float32, device=self.device
            )
            w_t.normal_(std=init_std)
            self.w_t = nn.Parameter(w_t)

        patch_proj.normal_(std=np.sqrt(2 / (self.n_state * self.patch_size ** 2)))
        w_pos.normal_(std=init_std)

    self.patch_proj = nn.Parameter(patch_proj)
    self.w_pos = nn.Parameter(w_pos)

    # Reshaped to (1, C, 1, 1), presumably to broadcast against NCHW images.
    # NOTE(review): these are plain tensor attributes, not registered buffers,
    # so they are absent from state_dict() and are not moved by Module.to();
    # they live on self.device as created here.
    self.channel_means = torch.tensor(
        image_channel_means, dtype=torch.float32, device=self.device
    )[None, :, None, None]
    self.channel_stds = torch.tensor(
        image_channel_stds, dtype=torch.float32, device=self.device
    )[None, :, None, None]
    self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)