in fairscale/optim/oss.py [0:0]
def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
    """Update the consolidated state_dict list, one per rank.

    Arguments:
        recipient_rank (int): the rank on which to materialize the full state dict.
            -1 is a special value, meaning that all ranks should hold the consolidated state.

    .. warning:: This needs to be called on all replicas.
    """

    # Sync lr and other attributes in case they have been updated since the last step
    # (e.g. by a scheduler driving the OSS wrapper's param_groups)
    OSS._sync_param_groups(self.param_groups, self.optim.param_groups)

    # Pull the sharded state from all the other replicas
    # Store all the states in order, rank by rank
    logging.debug("Pulling the sharded optimizer state from all replicas")
    self._all_states = []
    should_collect_state = self.rank == recipient_rank or recipient_rank == -1
    should_send_state = self.rank != recipient_rank

    # NCCL requires CUDA tensors for all communication primitives
    dist_device = torch.device("cuda") if self.backend == dist.Backend.NCCL else self._default_device
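
    # Round-robin exchange: each rank takes one turn broadcasting its local optimizer
    # shard, while the recipient rank(s) collect every shard (moved to CPU) in rank order.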
    for rank in range(self.world_size):
        if rank == self.rank:
            if should_collect_state:
                logging.debug("Saving self state")
                self._all_states.append(
                    recursive_copy_to_device(self.optim.state_dict(), non_blocking=True, device=torch.device("cpu"))
                )
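
            # Even when this rank keeps its own copy above, it still has to broadcast
            # its shard (or a placeholder) so the other ranks' collective calls do not hang.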

            # Sync with other replicas
            state_to_share = (
                self.optim.state_dict()
                if should_send_state
                else torch.tensor([0], dtype=torch.uint8, device=dist_device)
            )
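            # The branch below picks the broadcast primitive: the in-house
            # _broadcast_object fallback for GPUs the helper reports as older than
            # compute capability 5.0, torch.distributed.broadcast_object_list otherwise.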
            if _gpu_capabilities_older_than_50():
                _broadcast_object(
                    state_to_share, src_rank=self.global_rank, group=self.group, dist_device=dist_device
                )
            else:
                obj_list = [state_to_share]
                dist.broadcast_object_list(
                    obj_list,
                    src=self.global_rank,
                    group=self.group,
                )
        else:
            # Fetch the optim state from the other replicas
            if _gpu_capabilities_older_than_50():
                replica_state = _broadcast_object(
                    torch.tensor([0], dtype=torch.uint8, device=dist_device),
                    src_rank=self._local_to_global_rank[rank],
                    group=self.group,
                    dist_device=dist_device,
                )
            else:
                obj_list = [torch.tensor([0], dtype=torch.uint8, device=dist_device)]
                dist.broadcast_object_list(
                    obj_list,
                    src=self._local_to_global_rank[rank],
                    group=self.group,
                )
                replica_state = obj_list[0]
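
            # Only the recipient rank(s) keep the shard just received; it is copied to
            # CPU so the full optimizer state does not accumulate in device memory.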
            if should_collect_state:
                self._all_states.append(
                    recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
                )

            logging.debug("State from rank %s received", rank)