def consolidate_state_dict()

in fairscale/optim/oss.py


    def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
        """Update the consolidated state_dict list, one per rank.

        Arguments:
            recipient_rank (int): the rank on which to materialize the full state dict.
                -1 is a special value, meaning that all ranks should hold the consolidated state.

        .. warning:: This needs to be called on all replicas."""

        # Sync the lr and other attributes in case they have been updated
        OSS._sync_param_groups(self.param_groups, self.optim.param_groups)

        # Pull the sharded state from all the other replicas
        # Store all the states in order, rank by rank
        logging.debug("Pulling the sharded optimizer state from all replicas")

        self._all_states = []
        should_collect_state = self.rank == recipient_rank or recipient_rank == -1
        should_send_state = self.rank != recipient_rank
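        # With recipient_rank == -1 every rank collects and every rank shares its own shard;
        # otherwise only the recipient collects, and it never re-sends its own state.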

        # NCCL requires CUDA tensors for all communication primitives
        dist_device = torch.device("cuda") if self.backend == dist.Backend.NCCL else self._default_device

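        # Each rank takes one turn as the broadcast source; every other rank joins the matching
        # collective, and ranks that collect keep a CPU copy of what they receive.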
        for rank in range(self.world_size):
            if rank == self.rank:
                if should_collect_state:
                    logging.debug("Saving self state")
                    self._all_states.append(
                        recursive_copy_to_device(self.optim.state_dict(), non_blocking=True, device=torch.device("cpu"))
                    )

                # Sync with other replicas
                state_to_share = (
                    self.optim.state_dict()
                    if should_send_state
                    else torch.tensor([0], dtype=torch.uint8, device=dist_device)
                )
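                # When this rank is the recipient it has nothing to send, so it broadcasts a small
                # dummy tensor just to keep the collective matched across ranks.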
                if _gpu_capabilities_older_than_50():
                    _broadcast_object(
                        state_to_share, src_rank=self.global_rank, group=self.group, dist_device=dist_device
                    )
                else:
                    obj_list = [state_to_share]
                    dist.broadcast_object_list(
                        obj_list,
                        src=self.global_rank,
                        group=self.group,
                    )
            else:
                # Fetch the optim state from the other replicas
                if _gpu_capabilities_older_than_50():
                    replica_state = _broadcast_object(
                        torch.tensor([0], dtype=torch.uint8, device=dist_device),
                        src_rank=self._local_to_global_rank[rank],
                        group=self.group,
                        dist_device=dist_device,
                    )
                else:
                    obj_list = [torch.tensor([0], dtype=torch.uint8, device=dist_device)]
                    dist.broadcast_object_list(
                        obj_list,
                        src=self._local_to_global_rank[rank],
                        group=self.group,
                    )
                    replica_state = obj_list[0]

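                # Only keep a CPU copy when this rank is meant to materialize the full state;
                # other ranks discard the payload they just received.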
                if should_collect_state:
                    self._all_states.append(
                        recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
                    )

                logging.debug("State from rank %s received", rank)
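
A minimal usage sketch, following fairscale's documented OSS API (the toy model, the SGD hyperparameters, the elided training loop, and the checkpoint filename are illustrative assumptions): every rank must take part in the consolidation collective, and only the recipient rank then materializes and saves the full state dict.

import torch
import torch.distributed as dist
from fairscale.optim import OSS

# Assumes the process group is already initialized (e.g. via dist.init_process_group)
# and that this script runs on every rank.
model = torch.nn.Linear(8, 8)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)

# ... run some training steps, then checkpoint ...

optimizer.consolidate_state_dict(recipient_rank=0)  # collective: call on every rank
if dist.get_rank() == 0:
    # After consolidation, state_dict() on the recipient holds the full optimizer state.
    torch.save(optimizer.state_dict(), "oss_checkpoint.pt")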