in dall_e/utils.py [0:0]
def __attrs_post_init__(self) -> None:
    """Initialise the ``nn.Module`` base and create the layer's parameters.

    Reads ``n_in``, ``n_out``, ``kw``, ``device`` and ``requires_grad`` from
    ``self``. NOTE(review): these are presumably attrs fields populated by
    the generated ``__init__`` before this hook runs -- the class definition
    is outside this view, confirm there.

    Creates two float32 parameters on ``self.device``:

    * ``self.w`` with shape ``(n_out, n_in, kw, kw)``, drawn from a
      zero-mean normal distribution with
      ``std = 1 / sqrt(n_in * kw ** 2)``.
    * ``self.b`` with shape ``(n_out,)``, filled with zeros.

    Both parameters are created with ``requires_grad=self.requires_grad``.
    """
    # nn.Module must be initialised before parameters can be assigned to it.
    super().__init__()

    # Build the tensors WITHOUT gradient tracking: an in-place op such as
    # ``normal_`` on a leaf tensor that requires grad raises RuntimeError,
    # which made ``requires_grad=True`` unusable.
    w = torch.empty(
        (self.n_out, self.n_in, self.kw, self.kw),
        dtype=torch.float32,
        device=self.device,
    )
    w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))

    b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device)

    # nn.Parameter defaults to requires_grad=True and ignores the flag of
    # the tensor it wraps, so the setting has to be passed explicitly for
    # ``requires_grad=False`` to take effect.
    self.w = nn.Parameter(w, requires_grad=self.requires_grad)
    self.b = nn.Parameter(b, requires_grad=self.requires_grad)