in src/nanotron/optim/zero.py
def _all_gather_params(self):
    """All-gather the updated parameter shards so every DP rank holds the full parameters."""
    # Flatten every parameter into a 1D view; these views share storage with the
    # parameters, so gathering into them updates the parameters in place.
    all_named_tensors_to_gather = [
        (name, param.view(-1))
        for named_param_groups in self.zero_named_param_groups
        for name, param in named_param_groups["named_params"]
    ]

    if len(all_named_tensors_to_gather) == 0:
        # Nothing to gather
        return

    if self.dp_pg.size() == 1:
        # Single DP rank: params are already fully updated
        return

    current_dp_rank = dist.get_rank(self.dp_pg)
    # For each parameter, the output buffers are the per-rank slices of its flattened
    # view (one slice per DP rank, empty if that rank owns nothing of this parameter),
    # and the input is the slice owned by the current rank.
    dist.all_gather_coalesced(
        output_tensor_lists=[
            [
                tensor[slice(*self.param_name_to_dp_rank_offsets[name][dp_rank])]
                if dp_rank in self.param_name_to_dp_rank_offsets[name]
                else torch.empty(0, dtype=tensor.dtype, device=tensor.device)
                for dp_rank in range(self.dp_pg.size())
            ]
            for name, tensor in all_named_tensors_to_gather
        ],
        input_tensor_list=[
            tensor[slice(*self.param_name_to_dp_rank_offsets[name][current_dp_rank])]
            if current_dp_rank in self.param_name_to_dp_rank_offsets[name]
            else torch.empty(0, dtype=tensor.dtype, device=tensor.device)
            for name, tensor in all_named_tensors_to_gather
        ],
        group=self.dp_pg,
    )
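
Below is a minimal, self-contained sketch of the slicing pattern the method relies on; all names, shapes, and offsets are invented for illustration and are not part of zero.py. Each DP rank owns a contiguous [start, end) range of a parameter's flattened view, and because param.view(-1) shares storage with the parameter, writing each rank's shard into its slice updates the parameter in place:

# Hypothetical sketch (not nanotron code): how per-rank offsets carve a
# flattened parameter into the buffers used by the all-gather above.
import torch

dp_size = 4
param = torch.zeros(2, 5)   # a parameter owned by the module
flat = param.view(-1)       # flattened view, shares storage with `param`

# Each DP rank owns a contiguous [start, end) range of the flattened parameter;
# a rank may own nothing and is then absent from the mapping (made-up values).
offsets = {0: (0, 3), 1: (3, 6), 2: (6, 10)}  # rank 3 owns no slice

# Output buffers are views into the same flat tensor, one per rank.
output_views = [
    flat[slice(*offsets[dp_rank])] if dp_rank in offsets else torch.empty(0)
    for dp_rank in range(dp_size)
]

# Simulate the gather: each rank's shard lands in its view, and the writes are
# visible through `param` because the views alias its storage.
for dp_rank, view in enumerate(output_views):
    view.copy_(torch.full_like(view, float(dp_rank)))

print(param)  # filled with 0s, 1s and 2s according to which rank owned each slice

In the real method, the collective plays the role of the loop above: every DP rank contributes the slice it owns (input_tensor_list), and the gathered shards from all ranks land in the per-rank views of each flattened parameter (output_tensor_lists), leaving every rank with fully updated parameters.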