in src/nanotron/models/starcoder2.py [0:0]
def init_model_randomly(self, config):
    """Initialize model parameters randomly.

    Note:
        LayerNorm weights are all 0 or all 1 depending on `apply_layernorm_1p`.
    """
    model = self
    initialized_parameters = set()
    # Map each module id to its fully-qualified name prefix so tied parameters can be resolved below
    module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
    # The root model itself maps to an empty prefix
    module_id_to_prefix[id(model)] = ""

    std = config.model.init_method.std
    sigma = config.model.init_method.std
    num_layers = config.model.model_config.num_hidden_layers
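    # `std` initializes most weights; `sigma / math.sqrt(2 * num_layers)` below applies the
    # depth-scaled (GPT-2 / Megatron-style) init to projections that write into the residual stream.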

    for param_name, param in model.named_parameters():
        assert isinstance(param, NanotronParameter)
        module_name, param_name = param_name.rsplit(".", 1)
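
        # Tied parameters (weights shared between modules) resolve to one canonical full name
        # so that each underlying tensor is initialized exactly once.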
        if param.is_tied:
            tied_info = param.get_tied_info()
            full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
                module_id_to_prefix=module_id_to_prefix
            )
        else:
            full_param_name = f"{module_name}.{param_name}"

        if full_param_name in initialized_parameters:
            # Already initialized
            continue

        module = model.get_submodule(module_name)
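
        # Each supported module type gets its own initialization rule.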
        if isinstance(module, TensorParallelColumnLinear):
            if "weight" == param_name:
                nn.init.normal_(module.weight, mean=0.0, std=std)
            elif "bias" == param_name:
                module.bias.zero_()
            else:
                raise ValueError(f"Unknown parameter {param_name}")
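        # Row-parallel linears are the projections that write back into the residual stream,
        # so their init is scaled down by sqrt(2 * num_layers).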
        elif isinstance(module, TensorParallelRowLinear):
            if "weight" == param_name:
                nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers))
            elif "bias" == param_name:
                module.bias.zero_()
            else:
                raise ValueError(f"Unknown parameter {param_name}")
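        # LayerNorm: weight filled with 1, bias zeroed (see the `apply_layernorm_1p` note in the docstring).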
        elif isinstance(module, LayerNorm):
            if "weight" == param_name:
                # TODO @thomasw21: Sometimes we actually want 0
                module.weight.fill_(1)
            elif "bias" == param_name:
                module.bias.zero_()
            else:
                raise ValueError(f"Unknown parameter {param_name}")
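        # MQAColumnLinears (multi-query attention projections) follow the same rule as column-parallel linears.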
        elif isinstance(module, MQAColumnLinears):
            if "weight" == param_name:
                nn.init.normal_(module.weight, mean=0.0, std=std)
            elif "bias" == param_name:
                module.bias.zero_()
            else:
                raise ValueError(f"Unknown parameter {param_name}")
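        # Embedding tables have no bias; only the weight is initialized.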
        elif isinstance(module, TensorParallelEmbedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
        else:
            raise Exception(f"Parameter {full_param_name} was not initialized")

        assert full_param_name not in initialized_parameters
        initialized_parameters.add(full_param_name)

    assert initialized_parameters == {
        param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
        if param.is_tied
        else name
        for name, param in model.named_parameters()
    }, f"Somehow the initialized set of parameters doesn't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}"