in fairscale/nn/data_parallel/fully_sharded_data_parallel.py
    def _register_pre_backward_hooks(self, outputs: Any) -> Any:
        """Register a pre-backward hook to run before the wrapped module's
        backward. Hooks should be attached to all outputs from the forward.

        Returns:
            outputs: new outputs with hooks registered if they require gradients.
        """
        if not torch.is_grad_enabled():
            return outputs  # don't register hooks if grad isn't enabled

        if self._is_root:
            # This actually means that only the root instance has
            # _post_backward_callback_queued defined. Accidentally accessing this field
            # will assert on all other instances, giving us a nice bug checker.
            self._post_backward_callback_queued = False

        def _pre_backward_hook(*unused: Any) -> None:
            # Try to queue the final backward callback only once for the root, so
            # that the final backward callback is attached to the outermost
            # backward graph task and called after all the backward
            # calls are completed.
            if self._is_root:
                self._queue_wait_for_post_backward()

            # All-gather full parameters or switch to the full params.
            #
            # This needs to be done on every pre_backward hook, even within the same
            # iteration (i.e. for checkpointed, multiple-forward-pass modules). This is
            # because after the forward pass (i.e. in the checkpoint inner graph), we always
            # switch to the fp32_shard in the ``forward`` function.
            #
            # We used to do this only after the ``self._pre_backward_hook_has_run``
            # boolean guard below, which is incorrect. It worked in pytorch < 1.9 for
            # some unknown reason, but pytorch 1.10 nightly exposed this bug.
            #
            # Note, both ``self._rebuild_full_params`` and ``self._use_full_params`` are
            # idempotent. So in case they are called unnecessarily, they don't incur much
            # overhead.
            if self.reshard_after_forward:
                self._rebuild_full_params()
            else:
                self._use_full_params()

            # Only run ``self._prep_grads_for_backward`` once per iteration (i.e. in case
            # there are multiple outputs or multiple forward passes).
            if not self._pre_backward_hook_has_run:
                self._pre_backward_hook_has_run = True
                # Start of a backward pass for the first time in an iteration.
                self.assert_state([TrainingState.IDLE, TrainingState.BACKWARD_PRE])
                # Prepare p.grad so that it is in the right shape, device, accumulated values, etc.
                self._prep_grads_for_backward()

            # Transition to BACKWARD_PRE state if currently IDLE. We can transition from BACKWARD_POST
            # to IDLE when FSDP is within activation checkpointing and called multiple times, due to the
            # extra forward pass for re-computation.
            if self.training_state == TrainingState.IDLE:
                self.training_state = TrainingState.BACKWARD_PRE
            self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])

        _registered = 0

        def _register_hook(t: torch.Tensor) -> torch.Tensor:
            # We don't register the pre_backward hook on the same tensor that has been
            # returned from an inner FSDP, unless it is the first one. This does
            # not cover all problematic cases though. A tensor not from an inner
            # FSDP can cause problems too:
            # ```
            #   x = layer1(input)
            #   state = [x]  # better change to x.detach(), not fixed by the following if-condition
            #   x = inner_fsdp_module_layer2(x)
            #   state.append(x)  # better change to x.detach(), but fixed by the following if-condition
            #   x = layer3(x)
            #   return x, state
            # ```
            # The tensors in `state`, if not detached, can be registered with
            # backward hooks (in addition to the `x` on the last line). In that case,
            # the pre-backward hook can fire multiple times, in an order that causes
            # the outer FSDP to crash.
            #
            # The best practice is for modules wrapped by FSDP to return one and only
            # one tensor to be used for backward. All other tensors returned should be
            # detached.
            nonlocal _registered
            assert self._output_pre_backward_hook_registered is not None
            if t.requires_grad and (_registered == 0 or id(t) not in self._output_pre_backward_hook_registered):
                t.register_hook(_pre_backward_hook)
                self._output_pre_backward_hook_registered.append(id(t))
                _registered += 1
            return t

        # Attach hooks to Tensor outputs.
        outputs = apply_to_tensors(_register_hook, outputs)

        return outputs
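
# ---------------------------------------------------------------------------
# Illustrative sketch (not FairScale code): why registering a hook on the
# forward outputs acts as a "pre-backward" hook. A hook registered via
# ``Tensor.register_hook`` fires when the gradient w.r.t. that tensor is
# computed, i.e. before autograd descends into the producing module's own
# backward. That ordering is what lets the hook above all-gather the full
# parameters just in time. The module and event names below are made up for
# this example.
import torch
import torch.nn as nn

linear = nn.Linear(4, 4)
events = []

out = linear(torch.randn(2, 4))
out.register_hook(lambda grad: events.append("pre-backward hook on output"))
for p in linear.parameters():
    p.register_hook(lambda grad: events.append("param grad ready"))

out.sum().backward()
# The output hook fires first, then one event per parameter:
#   ['pre-backward hook on output', 'param grad ready', 'param grad ready']
print(events)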
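
# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not FairScale code) of the best practice
# mentioned in ``_register_hook`` above: a module that will be wrapped by FSDP
# returns exactly one tensor carrying the backward graph and detaches any
# auxiliary tensors, so no extra pre-backward hooks get registered on them.
# The class and layer names are placeholders mirroring the comment's example
# (reuses the ``torch``/``nn`` imports from the sketch above).
class WrappedBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer1 = nn.Linear(8, 8)
        self.layer2 = nn.Linear(8, 8)  # stands in for the inner-FSDP-wrapped layer
        self.layer3 = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor):
        x = self.layer1(x)
        state = [x.detach()]      # detached: no hook will be registered on this copy
        x = self.layer2(x)
        state.append(x.detach())  # detached as well
        x = self.layer3(x)
        return x, state           # only ``x`` is used for backward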