in src/nanotron/trainer.py [0:0]
def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel:
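    # DDP wraps the model; unwrap it to reach the underlying NanotronModel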
    unwrapped_model = model.module if isinstance(model, DistributedDataParallel) else model

    # Load or initialize model weights
    reloaded_from_checkpoint = False
    if self.init_checkpoint_path is not None:
        # Load from a pre-existing checkpoint
        if check_path_is_local(self.init_checkpoint_path):
            # Reload from a training checkpoint
            log_rank(
                f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0
            )
            self.param_shard_metadata = load_weights(
                model=unwrapped_model,
                parallel_context=self.parallel_context,
                root_folder=self.init_checkpoint_path,
            )
            reloaded_from_checkpoint = True
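
    # Fall back to the init_method declared in the model config when no checkpoint was loaded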
    if not reloaded_from_checkpoint:
        log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO, rank=0)
        if self.parallel_context.context_parallel_size > 1:
            raise NotImplementedError("Init with Context parallel size > 1 not supported yet")
        if isinstance(self.config.model.init_method, ExistingCheckpointInit):
            # Initialize model from a pretrained model checkpoint (without optimizer, lr_scheduler...)
            self.param_shard_metadata = load_weights(
                model=unwrapped_model,
                parallel_context=self.parallel_context,
                root_folder=self.config.model.init_method.path,
            )
        elif isinstance(self.config.model.init_method, (RandomInit, SpectralMupInit)):
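            # Fresh random initialization (standard or spectral muP parametrization)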
            unwrapped_model.init_model_randomly(config=self.config)

            # Synchronize parameters so that the model is consistent
            # sync all params across dp
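            # (all-reduce with AVG over identically shaped tensors leaves every DP replica with the same values)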
            for _, param in sorted(model.named_parameters(), key=lambda x: x[0]):
                dist.all_reduce(param, op=dist.ReduceOp.AVG, group=self.parallel_context.dp_pg)
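
            # Tied parameters (weights shared between modules) may span ranks across several
            # parallel dimensions, so each tied group uses its own process group.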
            # sync tied params across tied groups
            for (_, group_ranks), param in sorted(
                get_tied_id_to_param(
                    parameters=model.parameters(),
                    root_module=unwrapped_model,
                ).items(),
                key=lambda x: x[0],
            ):
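                # group_ranks lists the world ranks sharing this tied parameter; fetch their process group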
                group = self.parallel_context.world_ranks_to_pg[group_ranks]
                dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group)
        else:
            raise ValueError(f"Unsupported {self.config.model.init_method}")

    return model