def save_weights()

in src/nanotron/serialize/weights.py [0:0]


def save_weights(model: nn.Module, parallel_context: ParallelContext, root_folder: Path):
    """Save every parameter of ``model`` as a safetensors file below ``root_folder / "model"``.

    Only processes whose rank in ``parallel_context.dp_cp_pg`` is 0 write anything; every
    other process returns immediately. Among the writing processes:

    - a process with a non-zero rank in ``parallel_context.ep_pg`` only saves state-dict
      entries whose name contains ``"experts"``;
    - a tied parameter is saved only by rank 0 of its tied group, under the full name
      returned by ``tied_info.get_full_name_from_module_id_to_prefix``;
    - a sharded parameter is saved with ``TensorMetadata`` (checkpoint version,
      local/global slice pairs, unsharded shape) to a path that encodes this process's
      position as computed by ``get_exp_tp_pp_rank_and_size_from``.

    Each file stores the tensor under the key ``"data"``.

    Args:
        model: Module whose ``state_dict()`` entries are saved.
        parallel_context: Provides the process groups that decide which process saves what.
        root_folder: Checkpoint root; files are written below ``root_folder / "model"``.

    Raises:
        NotImplementedError: If a state-dict entry does not resolve to a
            ``NanotronParameter`` (e.g. a buffer, for which ``model.get_parameter`` raises
            ``AttributeError``).
        Exception: Whatever ``save_file`` raises is logged on the failing process and
            re-raised unchanged.
    """
    root_folder = root_folder / "model"

    # We save only `dist.get_rank(parallel_context.dp_cp_pg) == 0`
    # (presumably the other dp/cp ranks hold identical copies of the weights).
    if dist.get_rank(parallel_context.dp_cp_pg) != 0:
        return

    # Loop-invariant: expert-parallel rank 0 saves all weights, whereas ep ranks > 0
    # only save expert weights.
    saves_only_experts = dist.get_rank(parallel_context.ep_pg) != 0

    module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
    # `named_modules()` registers the root under the name "", which would yield the prefix ".";
    # the root module must map to an empty prefix instead.
    module_id_to_prefix[id(model)] = ""

    for name, param_or_buffer in tqdm(model.state_dict().items(), desc="Saving weights"):
        if saves_only_experts and "experts" not in name:
            continue

        # `state_dict` returns plain tensors, which lose the NanotronParameter metadata
        # (tying / sharding info), so look the parameter up on the model itself.
        try:
            param = model.get_parameter(name)
        except AttributeError:
            # TODO @nouamanetazi: Handle buffers
            param = None

        if not isinstance(param, NanotronParameter):
            raise NotImplementedError("Parameters are required to be NanotronParameter")

        metadata = {}
        if param.is_tied:
            tied_info = param.get_tied_info()
            base_name = tied_info.get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
            tied_group = parallel_context.world_ranks_to_pg[tied_info.global_ranks]
            # Only the first rank of the group of the tied weights saves weights
            # TODO @thomasw21: We could rotate in order to balance the load.
            if dist.get_rank(tied_group) != 0:
                continue
        else:
            base_name = name

        if param.is_sharded:
            sharded_info: ShardedInfo = param.get_sharded_info()
            shard_group = parallel_context.world_ranks_to_pg[sharded_info.global_ranks]
            exp_tp_pp_rank_and_size = get_exp_tp_pp_rank_and_size_from(
                world_rank=get_global_rank(group=shard_group, group_rank=dist.get_rank(shard_group)),
                parallel_context=parallel_context,
            )
            is_expert_sharded = sharded_info.is_expert_sharded(parallel_context)
            # Stored alongside the tensor so the loader can reassemble the unsharded tensor.
            metadata = TensorMetadata(
                version=CHECKPOINT_VERSION,
                local_global_slices_pairs=sharded_info.local_global_slices_pairs,
                unsharded_shape=sharded_info.unsharded_shape,
            ).to_str_dict()
        else:
            exp_tp_pp_rank_and_size = None
            is_expert_sharded = False

        path = get_path(
            base_name,
            type=ObjectType.MODEL,
            exp_tp_pp_rank_and_size=exp_tp_pp_rank_and_size,
            is_expert_sharded=is_expert_sharded,
            prefix=root_folder,
        )
        path.parent.mkdir(exist_ok=True, parents=True)
        try:
            save_file(tensors={"data": param_or_buffer}, filename=path, metadata=metadata)
        except Exception:
            # No `rank=` filter on purpose: the failing write usually happens on a process
            # other than global rank 0, so the message must be emitted where the error is.
            # NOTE(review): relies on `log_rank` logging on every rank when `rank` is omitted.
            log_rank(
                f"Error saving {path} with {metadata}",
                logger=logger,
                level=logging.ERROR,
            )
            raise