src/nanotron/serialize/weights.py
def save_weights(model: nn.Module, parallel_context: ParallelContext, root_folder: Path):
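    """Save the model weights under `root_folder / "model"`, one file per parameter shard.

    Only ranks with `dist.get_rank(parallel_context.dp_cp_pg) == 0` write. Tied parameters
    are written once by the first rank of their tied group, and sharded parameters are
    written together with the metadata needed to reassemble them at load time.
    """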
    root_folder = root_folder / "model"

    # Only `dist.get_rank(parallel_context.dp_cp_pg) == 0` saves the weights
    if dist.get_rank(parallel_context.dp_cp_pg) != 0:
        return
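
    # Map each module's id to its fully-qualified name prefix so tied parameters can be
    # resolved to a single canonical name via `get_full_name_from_module_id_to_prefix` below.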
    module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
    # Fix the root_model
    module_id_to_prefix[id(model)] = ""

    # We chunk everything by `tp_world_size` to make sure we gather all the weights onto a single device before saving them
    for name, param_or_buffer in tqdm(model.state_dict().items(), desc="Saving weights"):
        # exp_rank=0 saves all weights whereas exp_rank>0 saves only the expert (MLP) weights
        if dist.get_rank(parallel_context.ep_pg) != 0:
            if "experts" not in name:
                continue

        # `state_dict` doesn't return a Param or a buffer, just a plain tensor, which loses some metadata
        try:
            param = model.get_parameter(name)
        except AttributeError:
            # TODO @nouamanetazi: Handle buffers
            param = None
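
        # Note: buffers are not handled yet, so a name that isn't a parameter ends up with
        # `param = None` and hits the `NotImplementedError` branch at the bottom of the loop.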
        if isinstance(param, NanotronParameter):
            metadata = {}
            if param.is_tied:
                tied_info = param.get_tied_info()
                base_name = tied_info.get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
                group_ranks = tied_info.global_ranks
                group = parallel_context.world_ranks_to_pg[group_ranks]
                # Only the first rank of the tied-weights group saves the weights
                # TODO @thomasw21: We could rotate in order to balance the load.
                if dist.get_rank(group) != 0:
                    continue
            else:
                base_name = name
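
            # Sharded parameters are saved with enough metadata (local -> global slice pairs and
            # the unsharded shape) for the loader to reassemble the full tensor from its shards.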
            if param.is_sharded:
                sharded_info: ShardedInfo = param.get_sharded_info()
                group = parallel_context.world_ranks_to_pg[sharded_info.global_ranks]
                exp_tp_pp_rank_and_size = get_exp_tp_pp_rank_and_size_from(
                    world_rank=get_global_rank(group=group, group_rank=dist.get_rank(group)),
                    parallel_context=parallel_context,
                )
                is_expert_sharded = sharded_info.is_expert_sharded(parallel_context)
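                # `save_file` stores metadata as str -> str, hence the `.to_str_dict()` serialization below.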
                metadata = TensorMetadata(
                    version=CHECKPOINT_VERSION,
                    local_global_slices_pairs=sharded_info.local_global_slices_pairs,
                    unsharded_shape=sharded_info.unsharded_shape,
                ).to_str_dict()
            else:
                exp_tp_pp_rank_and_size = None
                is_expert_sharded = False
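
            # `get_path` builds a per-parameter filename; for sharded parameters the
            # expert/tp/pp rank-and-size is used so each shard gets its own file.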
            path = get_path(
                base_name,
                type=ObjectType.MODEL,
                exp_tp_pp_rank_and_size=exp_tp_pp_rank_and_size,
                is_expert_sharded=is_expert_sharded,
                prefix=root_folder,
            )
            path.parent.mkdir(exist_ok=True, parents=True)
            try:
                tensors = {"data": param_or_buffer}
                save_file(tensors=tensors, filename=path, metadata=metadata)
            except Exception as e:
                log_rank(
                    f"Error saving {path} with {metadata}",
                    logger=logger,
                    level=logging.ERROR,
                    rank=0,
                )
                raise e
        else:
            raise NotImplementedError("Parameters are required to be NanotronParameter")
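

# Usage sketch (illustrative): with an already-initialized `ParallelContext` and a target
# checkpoint directory, e.g.:
#
#   save_weights(model=model, parallel_context=parallel_context, root_folder=Path("checkpoints/10"))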