in src/nanotron/optim/zero.py
def _partition_parameters(self) -> Dict[str, Dict[int, Tuple[int, int]]]:
    named_params = [
        (name, param)
        for named_param_group in self.zero_named_param_groups
        for name, param in named_param_group["named_params"]
        if param.requires_grad
    ]
    # Maps each model param to the optimizer's DP rank that is responsible for updating it.
    # We assume that parameters can be sharded across DP, i.e. a single parameter can be split across DP ranks.
    # This breaks some optimizers, such as Adafactor.
    # `param_name_to_dp_rank_offsets[name]` is a `Dict[int, Tuple[int, int]]`: keys are dp_rank, and values are
    # the (start, end) offsets of the slice of the param owned by that DP rank.
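    # For example (hypothetical parameter name and sizes, with 2 DP ranks):
    #   {"decoder.weight": {0: (0, 512), 1: (512, 1024)}}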
    param_name_to_dp_rank_offsets = {}
    # NOTE: save the original shapes before flattening the params
    # so that later on, we can reshape the params to their original shapes
    # for topology-agnostic optimizer states loading
    self._orig_param_shapes = {}
    for name, param in named_params:
        self._orig_param_shapes[name] = param.shape
    for name, param in named_params:
        # We assume the parameter is contiguous so that it can be sharded by flat offsets.
        assert param.is_contiguous(), f"Parameter {name} is not contiguous"
        numel = param.numel()
        padded_numel_per_dp = (numel - 1) // self.dp_pg.size() + 1  # ceil(numel / dp_size)
        sizes = np.full(shape=(self.dp_pg.size(),), fill_value=padded_numel_per_dp)
        remainder = padded_numel_per_dp * self.dp_pg.size() - numel
        # The last `remainder` ranks get one element less.
        if remainder > 0:
            # The guard is needed because `sizes[-0:]` selects the entire array, not an empty slice.
            sizes[-remainder:] -= 1
        end_offsets = np.cumsum(sizes)
        assert len(end_offsets) == self.dp_pg.size()
        assert end_offsets[-1] == numel, f"Somehow {end_offsets[-1]} != {numel}"
        # We also need the start offsets.
        start_offsets = np.concatenate([[0], end_offsets[:-1]])
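        # Worked example (hypothetical sizes): numel=10, dp_size=4
        #   padded_numel_per_dp=3, remainder=2, sizes=[3, 3, 2, 2]
        #   end_offsets=[3, 6, 8, 10], start_offsets=[0, 3, 6, 8]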
        param_name_to_dp_rank_offsets[name] = {
            dp_rank: (start_offsets[dp_rank], end_offsets[dp_rank])
            for dp_rank in range(self.dp_pg.size())
            if start_offsets[dp_rank] < end_offsets[dp_rank]  # Only if the slice is not empty.
        }
    log_rank("[ZeRO sharding] Size of optimizer params per rank:", logger=logger, level=logging.INFO, rank=0)
    # Total number of trainable elements; each element is owned by exactly one DP rank.
    all_numel = sum(
        param_name_to_dp_rank_offsets[name][dp_rank][1] - param_name_to_dp_rank_offsets[name][dp_rank][0]
        for name, param in named_params
        for dp_rank in range(self.dp_pg.size())
        if dp_rank in param_name_to_dp_rank_offsets[name]
    )
    for dp_rank in range(self.dp_pg.size()):
        acc_numel = sum(
            value[dp_rank][1] - value[dp_rank][0]
            for value in param_name_to_dp_rank_offsets.values()
            if dp_rank in value
        )
        log_rank(
            f"[ZeRO sharding] DP Rank {dp_rank} has {human_format(acc_numel)} out of {human_format(all_numel)} ({0 if all_numel == 0 else acc_numel / all_numel * 100:.2f}%) params' optimizer states",
            logger=logger,
            level=logging.INFO,
            rank=0,
        )
    return param_name_to_dp_rank_offsets
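A minimal sketch of how the returned offsets could be consumed when sharding a parameter's optimizer state, assuming a hypothetical contiguous parameter and a hand-written offsets dict shaped like the mapping above (not the actual nanotron call sites):

import torch

# Hypothetical flattened parameter of 10 elements, sharded across 4 DP ranks.
param = torch.randn(10)
offsets = {0: (0, 3), 1: (3, 6), 2: (6, 8), 3: (8, 10)}  # dp_rank -> (start, end), as in the mapping above

dp_rank = 2
start, end = offsets[dp_rank]
local_shard = param.view(-1)[start:end]  # this rank keeps optimizer state only for this slice
assert local_shard.numel() == end - start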