in glide_text2im/clip/utils.py [0:0]
def __attrs_post_init__(self) -> None:
    """Register the weight and optional bias parameters once the fields are set.

    Hook run at the end of the attrs-generated ``__init__`` (assuming the class
    is decorated with ``attr.s``), so ``n_in``, ``n_out``, ``use_bias``,
    ``use_admnet_init``, ``std``, ``extra_init_scale`` and ``device`` are
    already available.

    Two modes, selected by ``self.use_admnet_init``:

    * default: ``self.std`` is resolved to the given value, or to
      ``sqrt(2 / (n_in + n_out))`` when it is ``None``, and is then multiplied
      by ``extra_init_scale`` when that is set. The bias starts at zero.
    * admnet: ``self.std`` is left untouched and the bias is allocated with
      ``torch.empty`` (uninitialized memory).

    In both modes the float32 weight of shape ``(n_out, n_in)`` comes from
    ``torch.empty`` and is not filled in by this method; ``self.std`` is only
    recorded. NOTE(review): presumably the values are overwritten by a
    checkpoint load or an init step elsewhere -- confirm before using a
    freshly constructed module.

    Raises:
        ValueError: if ``use_admnet_init`` is combined with ``extra_init_scale``.
    """
    # Base-class init (presumably nn.Module) must run before any Parameter
    # is assigned below.
    super().__init__()

    # Resolve per-mode settings first, then build the parameters once.
    if self.use_admnet_init:
        if self.extra_init_scale is not None:
            raise ValueError("extra_init_scale incompatible with admnet init")
        new_bias = torch.empty
    else:
        if self.std is None:
            self.std = math.sqrt(2 / (self.n_in + self.n_out))
        if self.extra_init_scale is not None:
            self.std = self.std * self.extra_init_scale
        new_bias = torch.zeros

    self.w = nn.Parameter(
        torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device)
    )

    if self.use_bias:
        self.b = nn.Parameter(
            new_bias((self.n_out,), dtype=torch.float32, device=self.device)
        )
        # Custom tag on the Parameter; presumably read by optimizer set-up
        # elsewhere to exempt the bias from weight decay -- confirm.
        self.b.weight_decay_level = "disable"  # type: ignore