in fairscale/nn/data_parallel/fully_sharded_data_parallel.py [0:0]
@contextlib.contextmanager
def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
"""
A context manager to expose full params for the current FSDP instance.
Can be useful *after* forward/backward for a model to get the params for
additional processing or checking. Parameters will be gathered in full
precision (e.g., FP32).
.. note:: This can be used on inner FSDPs.
.. note:: This can *not* be used within a forward or backward pass. Nor
can forward and backward be started from within this context.
.. note:: The full parameters will be freed after the context manager
exits; it is up to the caller to clone them if needed.
.. note:: The full parameters can be modified, but only the portion
corresponding to the local param shard will persist after the
context manager exits (unless ``volatile=True``, in which case there
are no guarantees about persistence).
Args:
recurse (bool, Optional): recursively summon all params for nested
FSDP instances (default: True)
volatile (bool, Optional): if ``True``, modifications to params are
not guaranteed to persist after the context manager exists;
enabling this can be slightly more efficient (default: False)
"""
if recurse:
with contextlib.ExitStack() as stack:
# Summon all params for any nested FSDP instances.
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
stack.enter_context(module.summon_full_params(recurse=False, volatile=volatile))
# Yield to the caller, with full params in all nested instances.
yield
# Exiting from the ExitStack will re-shard params.
return
else:
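        # Make sure any in-flight CUDA work on this device (e.g., from an
        # earlier forward/backward) has finished before gathering and
        # exposing the full parameters.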
torch.cuda.synchronize()
self._lazy_init()
self.assert_state(TrainingState.IDLE)
# Set the state so that we assert when trying to go into
# forward/backward.
self.training_state = TrainingState.SUMMON_FULL_PARAMS
full_tensors = self._rebuild_full_params(force_full_precision=True)
assert full_tensors is not None
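        # ``full_tensors`` holds one (full_tensor, safe_to_free) pair per
        # (non-shared) param managed by this instance; see the loop below.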
with contextlib.ExitStack() as stack:
if self.module.is_flattened:
# Update flattened views to point to fully-sized tensors. We
# use self.params instead of full_tensors since the
# latter may contain padding.
stack.enter_context(
self.module.unflatten_params(
flat_params=[p.data for p in self.params[: self._num_flatten_params]]
)
)
try:
yield
finally:
stack.close()
non_shared_params = self.params
# filter out shared params for all but the owner FSDP module.
if len(full_tensors) < len(non_shared_params):
non_shared_params = self.non_shared_params()
assert len(full_tensors) == len(
non_shared_params
), f"{len(full_tensors)} vs. {len(non_shared_params)}"
for p, (full_tensor, safe_to_free) in zip(non_shared_params, full_tensors):
if not volatile:
# Copy any changes made to the full params back into
# the corresponding local shards.
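                        # ``_get_shard`` returns this rank's slice of the
                        # (padded) full tensor plus the padding amount, which
                        # is not needed here.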
local_shard, _ = self._get_shard(full_tensor)
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
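                    # Release the gathered full tensor's storage when it does
                    # not alias memory that is still needed (as indicated by
                    # ``safe_to_free``).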
if safe_to_free:
free_storage_(full_tensor)
self.has_full_params = False
if self.ssd_offload:
# Store tensors in the SSD buffer and free param storage.
for p in self.params:
assert isinstance(p, SsdFlatParameter)
p.to_file()
else:
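                    # Point ``p.data`` back at the local FP32 shard now that
                    # the gathered full params are gone.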
self._use_fp32_param_shard()
self.training_state = TrainingState.IDLE
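
# Example (a sketch): one typical way to use ``summon_full_params`` after a
# backward pass to snapshot the full, unsharded FP32 weights. It assumes a
# distributed process group has already been initialized; ``module``,
# ``data`` and ``target`` are placeholders for the caller's own model and
# batch, not names from this file.
#
#     model = FullyShardedDataParallel(module.cuda())
#     loss = (model(data) - target).pow(2).sum()
#     loss.backward()
#     with model.summon_full_params(recurse=True):
#         # Full params are only valid inside this context; clone anything
#         # that must outlive it.
#         full_params = {n: p.detach().cpu().clone() for n, p in model.named_parameters()}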