def _all_gather_params()

in src/nanotron/optim/zero.py


    def _all_gather_params(self):
        """All gather updated params"""
        all_named_tensors_to_gather = [
            (name, param.view(-1))
            for named_param_groups in self.zero_named_param_groups
            for name, param in named_param_groups["named_params"]
        ]

        if len(all_named_tensors_to_gather) == 0:
            # Nothing to gather
            return

        if self.dp_pg.size() == 1:
            # With a single data-parallel rank, params are already complete locally
            return

        current_dp_rank = dist.get_rank(self.dp_pg)
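        # Coalesced all-gather: each rank contributes the contiguous slice it owns
        # (or an empty tensor if it owns none of a given param), and the output
        # lists are views into the flattened params, so the gathered shards are
        # written back in place.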
        dist.all_gather_coalesced(
            output_tensor_lists=[
                [
                    tensor[slice(*self.param_name_to_dp_rank_offsets[name][dp_rank])]
                    if dp_rank in self.param_name_to_dp_rank_offsets[name]
                    else torch.empty(0, dtype=tensor.dtype, device=tensor.device)
                    for dp_rank in range(self.dp_pg.size())
                ]
                for name, tensor in all_named_tensors_to_gather
            ],
            input_tensor_list=[
                tensor[slice(*self.param_name_to_dp_rank_offsets[name][current_dp_rank])]
                if current_dp_rank in self.param_name_to_dp_rank_offsets[name]
                else torch.empty(0, dtype=tensor.dtype, device=tensor.device)
                for name, tensor in all_named_tensors_to_gather
            ],
            group=self.dp_pg,
        )
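
For illustration, here is a minimal, self-contained sketch (not part of nanotron) of the slicing pattern that `param_name_to_dp_rank_offsets` drives above: each DP rank owns a contiguous `[start, end)` slice of the flattened parameter, ranks absent from the mapping own nothing, and concatenating the per-rank slices in rank order reconstructs the full tensor, which is what the coalesced all-gather achieves in place. The parameter name, offsets, and DP world size below are made up for the example.

import torch

# Hypothetical offsets for one flattened parameter of 10 elements, sharded
# across a DP group of size 3; rank 2 owns no elements and is therefore
# absent from the mapping (mirroring the `dp_rank in ...` checks above).
param_name_to_dp_rank_offsets = {"model.weight": {0: (0, 6), 1: (6, 10)}}

flat_param = torch.arange(10, dtype=torch.float32)
offsets = param_name_to_dp_rank_offsets["model.weight"]

# Per-rank views: the slice a rank owns, or an empty tensor if it owns nothing.
per_rank_slices = [
    flat_param[slice(*offsets[dp_rank])]
    if dp_rank in offsets
    else torch.empty(0, dtype=flat_param.dtype)
    for dp_rank in range(3)
]

# Concatenating the shards in rank order reconstructs the full parameter.
assert torch.equal(torch.cat(per_rank_slices), flat_param)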