def summon_full_params()

in fairscale/nn/data_parallel/fully_sharded_data_parallel.py [0:0]


    def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
        """
        Temporarily expose the full (unsharded) params of this FSDP instance.

        Meant to be used as a context manager, typically *after*
        forward/backward, when the complete parameters are needed for
        inspection or additional processing. Parameters are rebuilt in full
        precision (e.g., FP32).

        .. note:: This also works on inner (nested) FSDP instances.

        .. note:: This must *not* be entered during a forward or backward
            pass, and forward/backward must not be started while the context
            is active.

        .. note:: The full parameters are freed once the context manager
            exits; callers that need them afterwards must clone them.

        .. note:: The full parameters may be modified in place, but only the
            part that belongs to the local param shard is kept after the
            context manager exits (and with ``volatile=True`` there is no
            guarantee that anything is kept).

        Args:
            recurse (bool, Optional): also summon the params of every nested
                FSDP instance (default: True)
            volatile (bool, Optional): if ``True``, modifications to params are
                not guaranteed to persist after the context manager exits;
                enabling this can be slightly more efficient (default: False)
        """
        if recurse:
            with contextlib.ExitStack() as nested_contexts:
                # Enter the non-recursive context of every FSDP instance
                # yielded by self.modules(); the filter is lazy so modules are
                # visited and entered one at a time.
                fsdp_instances = (m for m in self.modules() if isinstance(m, FullyShardedDataParallel))
                for fsdp in fsdp_instances:
                    nested_contexts.enter_context(fsdp.summon_full_params(recurse=False, volatile=volatile))
                # Hand control to the caller while all instances are unsharded.
                yield
            # Leaving the ExitStack above re-shards the params.
            return

        # Non-recursive case: handle only this instance's own params.
        torch.cuda.synchronize()
        self._lazy_init()
        self.assert_state(TrainingState.IDLE)
        # Mark the state so that an attempt to enter forward/backward while
        # the full params are summoned trips an assertion.
        self.training_state = TrainingState.SUMMON_FULL_PARAMS
        full_tensors = self._rebuild_full_params(force_full_precision=True)
        assert full_tensors is not None
        # Local edits are written back to the shards unless volatile is set.
        persist_changes = not volatile
        with contextlib.ExitStack() as unflatten_stack:
            if self.module.is_flattened:
                # Re-point the flattened views at the fully-sized tensors.
                # self.params is used rather than full_tensors because the
                # latter may contain padding.
                flat_data = [param.data for param in self.params[: self._num_flatten_params]]
                unflatten_stack.enter_context(self.module.unflatten_params(flat_params=flat_data))
            try:
                yield
            finally:
                # Undo the unflattened views before touching the shards.
                unflatten_stack.close()
                non_shared_params = self.params
                # Fewer rebuilt tensors than params: shared params are
                # filtered out for all but the owner FSDP module.
                if len(non_shared_params) > len(full_tensors):
                    non_shared_params = self.non_shared_params()
                assert len(full_tensors) == len(non_shared_params), f"{len(full_tensors)} vs. {len(non_shared_params)}"
                for param, (full_tensor, safe_to_free) in zip(non_shared_params, full_tensors):
                    if persist_changes:
                        # Write whatever the caller changed in the full param
                        # back into the matching local shard.
                        local_shard, _unused = self._get_shard(full_tensor)
                        fp32_shard = param._fp32_shard
                        fp32_shard.copy_(local_shard.view_as(fp32_shard))
                    if safe_to_free:
                        free_storage_(full_tensor)
                self.has_full_params = False
                if not self.ssd_offload:
                    self._use_fp32_param_shard()
                else:
                    # Store tensors in the SSD buffer and free param storage.
                    for ssd_param in self.params:
                        assert isinstance(ssd_param, SsdFlatParameter)
                        ssd_param.to_file()
                self.training_state = TrainingState.IDLE