def _init_weights()

in vision/m4/models/vllama3/modeling_vllama3.py [0:0]


    def _init_weights(self, module):
        """Initialize the parameters of ``module`` in place, dispatching on its type.

        NOTE(review): presumably invoked once per submodule (e.g. through
        ``nn.Module.apply``) — confirm against the caller.

        Handled module types (any other module is left untouched):

        * ``MLP`` / ``SimpleMLP``: every ``nn.Linear`` below it gets weights drawn
          from a normal with mean 0 and std ``sqrt(0.4 / (hidden_size * factor))``
          and a zeroed bias. ``factor`` is 2 for layers whose qualified name
          contains ``"down_proj"`` and 1 otherwise.
        * ``PerceiverResampler``: ``latents`` get mean 0, std
          ``sqrt(1 / hidden_size)``. Every ``nn.Linear`` below it is initialized
          like the MLP case, with ``factor = 2 * perceiver_config.resampler_depth``
          for layers whose name contains ``"o_proj"``.
        * ``nn.Embedding``: weights get mean 0, std ``sqrt(1 / hidden_size)``.
          The ``padding_idx`` row, if any, is zeroed.
        * ``nn.Linear`` whose ``out_features`` equals
          ``self.out_additional_features``: weights get mean 0, std
          ``sqrt(1 / in_features)`` and a zeroed bias.

        Args:
            module: The module whose parameters are (re)initialized.
        """

        def init_a_linear(module, mean=0.0, std=self.config.initializer_range):
            # Each parameter is written inside DeepSpeed's gather context.
            # Presumably this materializes ZeRO-partitioned tensors for the write
            # and re-partitions them on exit (modify=True) — confirm in the helper.
            with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
                module.weight.data.normal_(mean=mean, std=std)
                if module.bias is not None:
                    with ContextManagers(deepspeed_gathered_parameters_context_manager(module.bias, modify=True)):
                        module.bias.data.zero_()

        def init_linears_in(parent, scaled_name_fragment, scaled_factor):
            # Initialize every nn.Linear under `parent` (named_modules() recurses)
            # with std = sqrt(0.4 / (hidden_size * factor)). `factor` is
            # `scaled_factor` for layers whose qualified name contains
            # `scaled_name_fragment` and 1.0 otherwise. Presumably the larger factor
            # shrinks the output projections that feed a residual stream.
            for sub_module_name, sub_module in parent.named_modules():
                if isinstance(sub_module, nn.Linear):
                    factor = scaled_factor if scaled_name_fragment in sub_module_name else 1.0
                    init_a_linear(sub_module, std=(0.4 / (self.config.hidden_size * factor)) ** 0.5)

        # A single if/elif chain: each module is matched by at most one branch.
        if isinstance(module, (MLP, SimpleMLP)):
            init_linears_in(module, "down_proj", 2.0)
        elif isinstance(module, PerceiverResampler):
            with ContextManagers(deepspeed_gathered_parameters_context_manager(module.latents, modify=True)):
                module.latents.data.normal_(mean=0.0, std=(1.0 / self.config.hidden_size) ** 0.5)
            init_linears_in(module, "o_proj", 2.0 * self.config.perceiver_config.resampler_depth)
        elif isinstance(module, nn.Embedding):
            # NOTE: a dedicated DecoupledEmbedding initializer (skipping a frozen
            # main table, initializing `additional_embedding` separately) used to
            # live here and is disabled. Such modules only reach this branch if
            # they subclass nn.Embedding.
            with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
                module.weight.data.normal_(mean=0.0, std=(1.0 / self.config.hidden_size) ** 0.5)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear) and module.out_features == self.out_additional_features:
            # NOTE(review): this matches purely on output width. Any other
            # nn.Linear whose out_features happens to equal
            # `out_additional_features` is re-initialized too. It also assumes
            # `self` defines `out_additional_features` — confirm for every model
            # class that uses this initializer.
            init_a_linear(module, std=(1.0 / module.in_features) ** 0.5)