def _init_weights()

in modules/SwissArmyTransformer/sat/model/official/t5_model.py [0:0]


    def _init_weights(self, weight, module, name):
        """Fill one named weight of an MLP / attention sub-module in place.

        Every draw is a zero-mean normal whose std is ``self.init_method_std``
        multiplied by ``size ** -0.5``. ``size`` is an attribute (or product of
        attributes) read from ``module`` and chosen by the layer ``name``
        (presumably the layer's input width, i.e. its fan-in — confirm against
        the layer definitions).

        Args:
            weight: tensor filled in place via ``torch.nn.init.normal_``.
            module: the owning sub-module; must be an ``MLP``, ``SelfAttention``
                or ``CrossAttention`` instance.
            name: name of the linear layer inside ``module``.

        Raises:
            NotImplementedError: with ``name`` if the layer name is not handled
                for this module type, or with ``module`` if the module type is
                not handled at all.
        """
        base_std = self.init_method_std

        def draw(tensor, size):
            # In-place N(0, (base_std * size**-0.5)^2) fill of ``tensor``.
            torch.nn.init.normal_(tensor, mean=0, std=base_std * (size ** -0.5))

        if isinstance(module, MLP):
            if name == "dense_h_to_4h":
                draw(weight, module.hidden_size)
                return
            if name == "dense_4h_to_h":
                draw(weight, module.inner_hidden_size)
                return
            raise NotImplementedError(name)

        if isinstance(module, SelfAttention):
            if name == "dense":
                draw(weight, module.inner_hidden_size)
                return
            if name != "query_key_value":
                raise NotImplementedError(name)
            # Fused projection: fill the whole matrix first. Then redraw the
            # leading ``inner_hidden_size`` rows (presumably the query slice —
            # confirm the fused layout) with a smaller std. Slicing yields a
            # view, so the second in-place fill writes into ``weight``.
            # The order of the two draws matters for seeded reproducibility.
            draw(weight, module.hidden_size)
            query_rows = weight[:module.inner_hidden_size]
            draw(query_rows, module.hidden_size * module.hidden_size_per_attention_head)
            return

        if isinstance(module, CrossAttention):
            if name == "query":
                draw(weight, module.hidden_size * module.hidden_size_per_attention_head)
                return
            if name == "key_value":
                draw(weight, module.hidden_size)
                return
            if name == "dense":
                draw(weight, module.inner_hidden_size)
                return
            raise NotImplementedError(name)

        raise NotImplementedError(module)