in muse/modeling_transformer_v2.py [0:0]
def _init_weights(self, module):
"""
Initialize the weights according to the original implementation.
https://github.com/google-research/maskgit/blob/main/maskgit/nets/maskgit_transformer.py#L37
"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
nn.init.trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
nn.init.trunc_normal_(module.weight, std=0.02)
elif isinstance(module, (LayerNorm, RMSNorm)):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(1.0)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()