in modules/SwissArmyTransformer/sat/model/official/t5_model.py [0:0]
def _init_weights(self, weight, module, name):
    """Fill one parameter tensor in place with a fan-in-scaled normal draw.

    Every draw uses mean 0 and ``std = self.init_method_std * fan_in ** -0.5``.
    The fan-in is chosen from the type of ``module`` that owns the weight and
    from the weight's ``name`` inside it.

    Args:
        weight: Tensor to (re)initialise in place.
        module: Owning sub-module; one of ``MLP``, ``SelfAttention`` or
            ``CrossAttention``.
        name: Name of the weight within ``module``.

    Returns:
        None; ``weight`` is modified in place.

    Raises:
        NotImplementedError: ``module`` is of any other type (raised with the
            module), or ``name`` is not a weight known for that module type
            (raised with the name).
    """
    base_std = self.init_method_std

    def fill(tensor, fan_in):
        # In-place N(0, std^2) draw with std = base_std / sqrt(fan_in).
        torch.nn.init.normal_(tensor, mean=0, std=base_std * fan_in ** -0.5)

    if isinstance(module, MLP):
        if name == "dense_h_to_4h":
            fill(weight, module.hidden_size)
        elif name == "dense_4h_to_h":
            fill(weight, module.inner_hidden_size)
        else:
            raise NotImplementedError(name)
        return

    if isinstance(module, SelfAttention):
        if name == "query_key_value":
            # First draw the whole fused tensor at the hidden_size scale, then
            # redraw its leading `inner_hidden_size` entries along dim 0 at the
            # smaller (hidden_size * head_dim) scale. NOTE(review): this assumes
            # the query projection occupies that leading slice; confirm against
            # SelfAttention's fused layout.
            fill(weight, module.hidden_size)
            fill(weight[:module.inner_hidden_size],
                 module.hidden_size * module.hidden_size_per_attention_head)
        elif name == "dense":
            fill(weight, module.inner_hidden_size)
        else:
            raise NotImplementedError(name)
        return

    if isinstance(module, CrossAttention):
        if name == "query":
            fill(weight, module.hidden_size * module.hidden_size_per_attention_head)
        elif name == "key_value":
            fill(weight, module.hidden_size)
        elif name == "dense":
            fill(weight, module.inner_hidden_size)
        else:
            raise NotImplementedError(name)
        return

    raise NotImplementedError(module)