in src/image_gen_aux/preprocessors/teed/teed.py [0:0]
def weight_init(m):
    """Initialise a convolution module's parameters in place.

    Only ``nn.Conv2d`` and ``nn.ConvTranspose2d`` instances are modified:
    - The weight is re-drawn with ``torch.nn.init.xavier_normal_`` (gain 1.0).
    - The bias, when the layer has one, is set to zero.

    Any other module is returned from untouched. The one-argument signature
    fits ``nn.Module.apply``. That is presumably how it is invoked; confirm
    at the call site.

    Args:
        m: The module to initialise.
    """
    # Conv2d covers the regular convolutions and ConvTranspose2d covers the
    # fusion layer. Both get identical initialisation, so one branch serves.
    if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        return
    torch.nn.init.xavier_normal_(m.weight, gain=1.0)
    # Layers built with bias=False have m.bias set to None.
    if m.bias is not None:
        torch.nn.init.zeros_(m.bias)