in vision/m4/models/vllama3/modeling_vllama3.py [0:0]
def _init_weights(self, module):
    """Initialize the parameters of ``module`` in place, based on its type.

    Module types handled here (any other module is left untouched):

    * ``MLP`` / ``SimpleMLP``: every ``nn.Linear`` reachable through
      ``named_modules()`` gets normal weights with
      ``std = sqrt(0.4 / hidden_size)``; linears whose name contains
      ``"down_proj"`` use ``std = sqrt(0.4 / (2 * hidden_size))``.
    * ``PerceiverResampler``: ``latents`` get ``std = sqrt(1 / hidden_size)``;
      nested linears get ``std = sqrt(0.4 / hidden_size)``, except those whose
      name contains ``"o_proj"``, which use
      ``std = sqrt(0.4 / (2 * resampler_depth * hidden_size))``.
    * ``nn.Embedding``: weights get ``std = sqrt(1 / hidden_size)`` and the
      ``padding_idx`` row, when set, is zeroed.
    * ``nn.Linear``: initialized with ``std = sqrt(1 / in_features)`` only when
      ``out_features == self.out_additional_features``.

    Every linear that is initialized also has its bias (if present) zeroed.
    All writes happen inside ``deepspeed_gathered_parameters_context_manager``
    contexts (presumably so DeepSpeed-partitioned parameters can be modified;
    confirm against that helper's definition).

    Args:
        module: The module whose parameters should be initialized.
    """

    # The default `std` is evaluated each time `_init_weights` runs, so
    # `self.config.initializer_range` must exist even though every call below
    # passes an explicit `std`.
    def _normal_init_linear(linear, mean=0.0, std=self.config.initializer_range):
        """Fill ``linear.weight`` from N(mean, std**2) and zero ``linear.bias`` if any."""
        with ContextManagers(deepspeed_gathered_parameters_context_manager(linear.weight, modify=True)):
            linear.weight.data.normal_(mean=mean, std=std)
            if linear.bias is not None:
                with ContextManagers(deepspeed_gathered_parameters_context_manager(linear.bias, modify=True)):
                    linear.bias.data.zero_()

    def _named_linears(root):
        """Lazily yield ``(name, sub_module)`` for each ``nn.Linear`` under ``root``."""
        return ((name, child) for name, child in root.named_modules() if isinstance(child, nn.Linear))

    # This check is independent of the if/elif chain that follows.
    if isinstance(module, (MLP, SimpleMLP)):
        for name, linear in _named_linears(module):
            # Down-projections get half the variance of the other linears.
            factor = 2.0 if "down_proj" in name else 1.0
            _normal_init_linear(linear, std=(0.4 / (self.config.hidden_size * factor)) ** 0.5)

    if isinstance(module, PerceiverResampler):
        with ContextManagers(deepspeed_gathered_parameters_context_manager(module.latents, modify=True)):
            module.latents.data.normal_(mean=0.0, std=(1.0 / self.config.hidden_size) ** 0.5)
        for name, linear in _named_linears(module):
            # Output projections are scaled down with the resampler depth; the
            # depth is only looked up when an "o_proj" linear is encountered.
            factor = 2.0 * self.config.perceiver_config.resampler_depth if "o_proj" in name else 1.0
            _normal_init_linear(linear, std=(0.4 / (self.config.hidden_size * factor)) ** 0.5)
    # NOTE(review): disabled branch kept for reference (indentation reconstructed);
    # delete it or restore it once DecoupledEmbedding handling is settled.
    # elif isinstance(module, DecoupledEmbedding):
    #     with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
    #         # Initialize the main embedding weights if they're not frozen
    #         if not module.partially_freeze:
    #             module.weight.data.normal_(mean=0.0, std=(1.0 / self.config.hidden_size) ** 0.5)
    #             if module.padding_idx is not None:
    #                 module.weight.data[module.padding_idx].zero_()
    #     # Initialize the additional embeddings
    #     if module.num_additional_embeddings > 0:
    #         with ContextManagers(deepspeed_gathered_parameters_context_manager(module.additional_embedding.weight, modify=True)):
    #             module.additional_embedding.weight.data.normal_(mean=0.0, std=(1.0 / self.config.hidden_size) ** 0.5)
    elif isinstance(module, nn.Embedding):
        with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
            module.weight.data.normal_(mean=0.0, std=(1.0 / self.config.hidden_size) ** 0.5)
            # Keep the padding row at zero.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.Linear):
        # NOTE(review): `out_additional_features` is read from `self` (the object
        # owning this method), not from `module` - confirm that attribute is
        # defined there.
        if module.out_features == self.out_additional_features:
            _normal_init_linear(module, std=(1.0 / (module.in_features)) ** 0.5)