def _partition_parameters()

in src/nanotron/optim/zero.py


    def _partition_parameters(self) -> Dict[str, Dict[int, Tuple[int, int]]]:
        named_params = [
            (name, param)
            for named_param_group in self.zero_named_param_groups
            for name, param in named_param_group["named_params"]
            if param.requires_grad
        ]

        # Maps each model parameter to the optimizer DP rank(s) responsible for updating it.

        # We assume that parameters can be sharded across DP, i.e. a single parameter can be split across DP ranks.
        # This breaks optimizers that need to see the whole parameter, such as Adafactor.
        # `param_name_to_dp_rank_offsets[name]` is a `Dict[int, Tuple[int, int]]`: keys are DP ranks, and each
        # `Tuple[int, int]` holds the (start, end) offsets of the slice of the parameter owned by that rank.
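        # For illustration only (hypothetical name and sizes): with 4 DP ranks and a parameter
        # "model.weight" of 10 elements, the mapping would contain
        #   param_name_to_dp_rank_offsets["model.weight"] == {0: (0, 3), 1: (3, 6), 2: (6, 8), 3: (8, 10)}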
        param_name_to_dp_rank_offsets = {}

        # NOTE: save the original shapes before flattening the params
        # so that later on, we can reshape the params to their original shapes
        # for topology-agnostic optimizer states loading
        self._orig_param_shapes = {}
        for name, param in named_params:
            self._orig_param_shapes[name] = param.shape

        for name, param in named_params:
            # We assume the parameter is contiguous so that sharding it is straightforward.
            assert param.is_contiguous(), f"Parameter {name} is not contiguous"

            numel = param.numel()
            padded_numel_per_dp = (numel - 1) // self.dp_pg.size() + 1  # ceil(numel / dp_size)
            sizes = np.full(shape=(self.dp_pg.size()), fill_value=padded_numel_per_dp)
            remainder = padded_numel_per_dp * self.dp_pg.size() - numel
            # The last `remainder` ranks get one element fewer
            if remainder > 0:
                # Guard with `remainder > 0`: `sizes[-0:]` would select the entire array, not an empty slice
                sizes[-remainder:] -= 1
            end_offsets = np.cumsum(sizes)
            assert len(end_offsets) == self.dp_pg.size()
            assert end_offsets[-1] == numel, f"Somehow {end_offsets[-1]} != {numel}"
            # We also need the start offsets: prepend 0 and drop the last end offset
            start_offsets = np.concatenate([[0], end_offsets[:-1]])
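            # Worked example (illustrative numbers): numel = 10, dp_pg.size() = 4
            #   padded_numel_per_dp = 3, remainder = 2, sizes = [3, 3, 2, 2]
            #   end_offsets = [3, 6, 8, 10], start_offsets = [0, 3, 6, 8]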

            param_name_to_dp_rank_offsets[name] = {
                dp_rank: (start_offsets[dp_rank], end_offsets[dp_rank])
                for dp_rank in range(self.dp_pg.size())
                if start_offsets[dp_rank] < end_offsets[dp_rank]  # Only if the slice is not empty.
            }

        log_rank("[ZeRO sharding] Size of optimizer params per rank:", logger=logger, level=logging.INFO, rank=0)
        all_numel = sum(
            param_name_to_dp_rank_offsets[name][dp_rank][1] - param_name_to_dp_rank_offsets[name][dp_rank][0]
            for name, param in named_params
            for dp_rank in range(self.dp_pg.size())
            if dp_rank in param_name_to_dp_rank_offsets[name]
        )
        for dp_rank in range(self.dp_pg.size()):
            acc_numel = sum(
                value[dp_rank][1] - value[dp_rank][0]
                for value in param_name_to_dp_rank_offsets.values()
                if dp_rank in value
            )
            log_rank(
                f"[ZeRO sharding] DP Rank {dp_rank} has {human_format(acc_numel)} out of {human_format(all_numel)} ({0 if all_numel == 0 else acc_numel / all_numel * 100:.2f}%) params' optimizer states",
                logger=logger,
                level=logging.INFO,
                rank=0,
            )

        return param_name_to_dp_rank_offsets
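
Below is a minimal, self-contained sketch of the same partitioning arithmetic for quick experimentation. The `toy_partition` helper and the example inputs are illustrative, not part of nanotron; it only mirrors the ceiling-division logic above, including the empty-slice case where a parameter has fewer elements than there are DP ranks.

    import numpy as np


    def toy_partition(numel: int, dp_size: int) -> dict:
        # Ceiling division: every rank is assigned at most `padded` elements.
        padded = (numel - 1) // dp_size + 1
        sizes = np.full(dp_size, padded)
        remainder = padded * dp_size - numel
        if remainder > 0:
            # The last `remainder` ranks get one element fewer.
            sizes[-remainder:] -= 1
        end = np.cumsum(sizes)
        start = np.concatenate([[0], end[:-1]])
        # Keep only the non-empty slices, as the real method does.
        return {rank: (int(start[rank]), int(end[rank])) for rank in range(dp_size) if start[rank] < end[rank]}


    print(toy_partition(10, 4))  # {0: (0, 3), 1: (3, 6), 2: (6, 8), 3: (8, 10)}
    print(toy_partition(5, 8))   # {0: (0, 1), ..., 4: (4, 5)}; ranks 5-7 own no slice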