def _post_backward_hook()

in fairscale/nn/data_parallel/fully_sharded_data_parallel.py


    def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
        """
        At the start of :func:`_post_backward_hook`, ``param.grad`` contains the
        full gradient for the local batch. The reduce-scatter op will replace
        ``param.grad`` with a single shard of the summed gradient across all
        GPUs. This shard will align with the current GPU rank. For example::

            before reduce_scatter:
                param.grad (GPU #0): [1, 2, 3, 4]
                param.grad (GPU #1): [5, 6, 7, 8]

            after reduce_scatter:
                param.grad (GPU #0): [6, 8]    # 1+5, 2+6
                param.grad (GPU #1): [10, 12]  # 3+7, 4+8

        The local GPU's ``optim.step`` is responsible for updating a single
        shard of params, also corresponding to the current GPU's rank. This
        alignment is created by :func:`_shard_parameters_`, which ensures that
        the local optimizer only sees the relevant parameter shard.
        """
        # First hook callback will see PRE state. If we have multiple params,
        # then subsequent hook callbacks will see POST state.
        self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
        self.training_state = TrainingState.BACKWARD_POST
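        # Note: param.grad can be None if this parameter received no gradient in
        # the current backward pass (e.g. it was unused in this forward), in which
        # case there is nothing to reduce.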
        if param.grad is None:
            return

        if hasattr(param, "_linked_param"):
            # This links to a shared param. We should finalize the linked param here.
            assert param.shape == (1,), param.shape
            # If the _is_shared flag is set, then this shared weight is indeed being
            # shared between different FSDP wrappers. Otherwise, they are linked but
            # likely in the same FSDP wrapper, which means we shouldn't finalize the
            # linked param.
            if hasattr(param._linked_param, "_is_shared") and param._linked_param._is_shared:
                param = param._linked_param

        assert param.grad is not None, param.shape
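        # A gradient that itself requires grad would imply double-backward
        # (higher-order gradients), which this reduction path does not support.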
        if param.grad.requires_grad:
            raise RuntimeError("FSDP only works with gradients that don't require gradients")

        if self._require_backward_grad_sync or self.reshard_after_forward:
            # Free full params. As a special case, we don't free the full params
            # when in a ``no_sync`` context (as inversely indicated by
            # ``self._require_backward_grad_sync``), since the params will not
            # get updated before the next forward. This saves networking
            # bandwidth but uses more GPU memory.
            self._free_full_params([param])

        if self.mixed_precision:
            # This is a no-op if reshard_after_forward is True, since we already
            # free the param shard when rebuilding the full params in the
            # pre_backward_hook.
            self._free_fp16_param_shard([param])

        # Switch param.data to the FP32 shard after backward; this is the sharded
        # copy that the local optimizer step updates.
        self._use_fp32_param_shard([param])

        if not self._require_backward_grad_sync:
            return

        # Wait for all work in the current stream to finish, then start the
        # reductions in post_backward stream.
        self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
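        # Note: wait_stream() does not block the CPU; it only records a dependency
        # on the GPU, so work later queued on the post_backward stream will wait
        # for the kernels already queued on the current stream. This guarantees
        # the freshly computed full gradient is ready before the reduction starts.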
        with torch.cuda.stream(self._streams["post_backward"]):
            orig_grad_data = param.grad.data

            if self.mixed_precision and self.fp32_reduce_scatter:
                # Cast grad to FP32.
                param.grad.data = param.grad.data.to(param.dtype)

            if self.gradient_predivide_factor > 1:
                # Pre-divide the grad; together with the post-division applied after
                # the reduction, this averages grads by world_size for consistency
                # with PyTorch DDP.
                param.grad.data.div_(self.gradient_predivide_factor)

            if param._is_sharded:
                assert self._reducer is not None
                # Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
                # param._saved_grad_shard. If this FSDP module was called multiple times, it's possible that multiple
                # gradient reductions will happen in an undefined order. Since addition is commutative, the order
                # doesn't matter, up to floating-point rounding.
                grad = param.grad.data
                # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                #
                # The effect on memory consumption is not usually significant. No extra memory is allocated if this
                # module is called only once, reduction happens quickly, or the tensor is bucketed. If the module is
                # called multiple times, and the backwards pass runs far enough ahead of the `post_backward` stream,
                # then we can end up with multiple unsharded gradients allocated and queued for reduction.
                #
                # We could guard against this by using CUDA events (see record_event, wait_event in torch.cuda.Stream).
                # That would ensure the `default` stream waits for the `post_backward` stream to complete the last
                # reduction for this module before scheduling additional reduction work. Then at most two unsharded
                # gradients would be allocated: one for a pending reduction and one for gradient computation.
                param.grad = None
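                # chunk_and_pad splits the unsharded grad into
                # process_group_reduce_scatter.size() equal-sized chunks, zero-padding
                # so the split is even. The reducer enqueues the reduce-scatter
                # asynchronously (possibly bucketing small grads) and invokes
                # callback_fn (i.e. _post_reduction_hook) with this rank's reduced
                # shard once the op completes.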
                callback_fn = functools.partial(self._post_reduction_hook, param)
                grad_chunks = chunk_and_pad(grad, self.process_group_reduce_scatter.size())
                self._reducer.reduce_scatter_async(
                    grad_chunks, group=self.process_group_reduce_scatter, callback_fn=callback_fn
                )
            else:
                # Currently the only way for _is_sharded to be False is if
                # world_size == 1. This could be relaxed in the future, in which
                # case grads should be all-reduced here.
                assert self.world_size == 1
                self._post_reduction_hook(param, param.grad.data)

            # After _post_backward_hook returns, orig_grad_data will eventually
            # go out of scope, at which point it could otherwise be freed for
            # further reuse by the main stream while the div/reduce_scatter/copy
            # are underway in the post_backward stream. See:
            # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
            orig_grad_data.record_stream(self._streams["post_backward"])
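
To make the docstring's example concrete, below is a minimal single-process sketch of
the reduce-scatter semantics. It is an illustration only: plain tensor ops stand in for
``torch.distributed.reduce_scatter``, and the ``simulate_reduce_scatter`` helper and
tensor values are invented for the example rather than taken from fairscale.

    import torch

    def simulate_reduce_scatter(per_rank_grads):
        """Sum the full gradients from all ranks, then give each rank one shard."""
        world_size = len(per_rank_grads)
        summed = torch.stack(per_rank_grads).sum(dim=0)  # elementwise sum across ranks
        return list(summed.chunk(world_size))            # rank i keeps chunk i

    grads = [torch.tensor([1., 2., 3., 4.]),   # param.grad on GPU #0 before reduce-scatter
             torch.tensor([5., 6., 7., 8.])]   # param.grad on GPU #1 before reduce-scatter
    shards = simulate_reduce_scatter(grads)
    print(shards[0])  # tensor([6., 8.])    <- the shard GPU #0 keeps
    print(shards[1])  # tensor([10., 12.])  <- the shard GPU #1 keeps

In the real hook, the analogous per-rank shard is delivered asynchronously to
``_post_reduction_hook`` via ``callback_fn``, and the ``record_stream`` call above keeps
the unsharded gradient's memory from being reused until the post_backward stream has
finished with it.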