fairscale/nn/data_parallel/fully_sharded_data_parallel.py
def _set_is_root(self) -> None:
    """If ``True``, implies that no other :class:`FullyShardedDataParallel`
    instance wraps this one. Called once by :func:`_lazy_init`.
    Also sets self.children_share_process_group = True if all child
    instances share the same process group. If some child instances use a
    different process group, self.clip_grad_norm_ will raise an error.
    """
    if self._is_root is not None:
        return
    # No FSDP instance wraps this, else _is_root would be set to False.
    self._is_root = True
    # If the final backward callback was never queued, the state should be IDLE.
    # If it was queued, the callback should have finished and reset the state
    # to IDLE.
    # This is asserted at the beginning of the forward pass, in the root
    # instance only. For children instances, if they are checkpointed, the
    # state will not be reset to IDLE after each inner forward/backward.
    self.assert_state(TrainingState.IDLE)
    # As the root, we now set all children instances to False and
    # give them a closure to try to queue a wait_for_post_backward.
    self.children_share_process_group = True
    for n, m in self.named_modules():
        # `n != ""` excludes self.
        if n != "" and isinstance(m, FullyShardedDataParallel):
            # We relax the assert for non-root instances to cover the case where
            # a nested, already-initialized module is wrapped in FSDP again later,
            # for example after training, to run inference.
            assert m._is_root is None or not m._is_root
            if m._is_root is None:
                m._is_root = False
            if m.process_group != self.process_group:
                self.children_share_process_group = False
            # If the child instance is in its own (smaller) world, that was probably
            # an attempt to avoid OOM. Gathering this child's optim state would then
            # probably cause OOM as well, so we won't do it.
            m.no_broadcast_optim_state = m.no_broadcast_optim_state or (
                (m.world_size == 1)
                and (m.world_size < self.world_size)
                and (m.process_group != self.process_group)
            )
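
Below is a minimal, hypothetical sketch (not part of the library) of how the root flag propagates: the first forward pass triggers _lazy_init, which calls _set_is_root on the outermost wrapper only. The single-process CPU "gloo" group, the env-var setup, and the toy modules are assumptions made purely for illustration and may need adjusting for a real multi-process launch.

# sketch_set_is_root.py -- illustrative only; assumes fairscale is installed
# and that a single-process "gloo" group suffices for this toy example.
import os

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

inner = FSDP(torch.nn.Linear(4, 4))        # becomes a child instance
model = FSDP(torch.nn.Sequential(inner))   # outermost wrapper becomes the root

model(torch.randn(2, 4))  # first forward runs _lazy_init -> _set_is_root

assert model._is_root is True               # nothing wraps `model`
assert inner._is_root is False              # set by the root in _set_is_root
assert model.children_share_process_group   # inner shares the default group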