in fairscale/nn/data_parallel/fully_sharded_data_parallel.py
def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
"""
At the start of :func:`_post_backward_hook`, ``param.grad`` contains the
full gradient for the local batch. The reduce-scatter op will replace
``param.grad`` with a single shard of the summed gradient across all
GPUs. This shard will align with the current GPU rank. For example::
before reduce_scatter:
param.grad (GPU #0): [1, 2, 3, 4]
param.grad (GPU #1): [5, 6, 7, 8]
after reduce_scatter:
param.grad (GPU #0): [6, 8] # 1+5, 2+6
param.grad (GPU #1): [10, 12] # 3+7, 4+8
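
In terms of raw collectives, this per-parameter reduction is roughly
equivalent to the sketch below (illustrative only: ``full_grad`` stands for
the local unsharded gradient, its length is assumed to divide evenly by the
world size, and the actual code in this hook goes through ``chunk_and_pad``
and ``self._reducer`` instead)::

    import torch
    import torch.distributed as dist

    world_size = dist.get_world_size()
    # Split the full local grad into one equal-sized chunk per rank.
    grad_chunks = list(full_grad.chunk(world_size))
    shard = torch.empty_like(grad_chunks[dist.get_rank()])
    # Sum corresponding chunks across ranks; each rank keeps one shard.
    dist.reduce_scatter(shard, grad_chunks, op=dist.ReduceOp.SUM)
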
The local GPU's ``optim.step`` is responsible for updating a single
shard of params, also corresponding to the current GPU's rank. This
alignment is created by :func:`_shard_parameters_`, which ensures that
the local optimizer only sees the relevant parameter shard.
"""
# First hook callback will see PRE state. If we have multiple params,
# then subsequent hook callbacks will see POST state.
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
self.training_state = TrainingState.BACKWARD_POST
if param.grad is None:
return
if hasattr(param, "_linked_param"):
# This links to a shared param. We should finalize the linked param here.
assert param.shape == (1,), param.shape
# If the _is_shared flag is set, then this shared weight is indeed being
# shared between different FSDP wrappers. Otherwise, they are linked but
# likely in the same FSDP wrapper, which means we shouldn't finalize the
# linked param.
if hasattr(param._linked_param, "_is_shared") and param._linked_param._is_shared:
param = param._linked_param
assert param.grad is not None, param.shape
if param.grad.requires_grad:
raise RuntimeError("FSDP only works with gradients that don't require gradients")
if self._require_backward_grad_sync or self.reshard_after_forward:
# Free full params. As a special case, we don't free the full params
# when in a ``no_sync`` context (as inversely indicated by
# ``self._require_backward_grad_sync``), since the params will not
# get updated before the next forward. This saves networking
# bandwidth but uses more GPU memory.
self._free_full_params([param])
if self.mixed_precision:
# This is a no-op if reshard_after_forward is True, since we already
# free the param shard when rebuilding the full params in the
# pre_backward_hook.
self._free_fp16_param_shard([param])
# Switch to FP32 shard after backward.
self._use_fp32_param_shard([param])
if not self._require_backward_grad_sync:
return
# Wait for all work in the current stream to finish, then start the
# reductions in post_backward stream.
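# (``Stream.wait_stream`` only inserts an ordering dependency on the GPU:
# work queued later on the post_backward stream waits for kernels already
# queued on the current stream; it does not block the CPU.)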
self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._streams["post_backward"]):
orig_grad_data = param.grad.data
if self.mixed_precision and self.fp32_reduce_scatter:
# Cast grad to FP32.
param.grad.data = param.grad.data.to(param.dtype)
if self.gradient_predivide_factor > 1:
# Pre-divide the grad; combined with the post-division applied after the
# reduce-scatter, this averages the grad by world_size for consistency
# with PyTorch DDP.
param.grad.data.div_(self.gradient_predivide_factor)
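# Illustrative arithmetic (assuming the usual post-division factor of
# world_size / gradient_predivide_factor in the reduction callback): with
# world_size 16 and a predivide factor of 4, grads are divided by 4 here
# and by 4 again after the reduce-scatter, i.e. by 16 overall, while keeping
# the intermediate sums in a safer range for reduced-precision reductions.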
if param._is_sharded:
assert self._reducer is not None
# Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
# param._saved_grad_shard. If this FSDP module was called multiple times, it's possible that multiple
# gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
# matter, neglecting rounding.
grad = param.grad.data
# Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
#
# The effect on memory consumption is not usually significant. No extra memory is allocated if this
# module is called only once, reduction happens quickly, or the tensor is bucketed. If the module is
# called multiple times, and the backwards pass runs far enough ahead of the `post_backward` stream,
# then we can end up with multiple unsharded gradients allocated and queued for reduction.
#
# We could guard against this by using CUDA events (see record_event, wait_event in torch.cuda.Stream).
# This would ensure the `default` stream waits for the `post_backward` stream to complete the last
# reduction for this module before scheduling additional reduction work. Then at most two unsharded
# gradients would be allocated: one for a pending reduction, and one for gradient computation.
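#
# A minimal sketch of that event-based guard (not implemented here; the
# attribute name ``_reduction_done_event`` is hypothetical):
#
#     # At the top of this hook: wait for this param's previous reduction.
#     if getattr(param, "_reduction_done_event", None) is not None:
#         torch.cuda.current_stream().wait_event(param._reduction_done_event)
#     ...
#     # Right after enqueueing the reduce-scatter on the post_backward stream:
#     param._reduction_done_event = torch.cuda.Event()
#     param._reduction_done_event.record(self._streams["post_backward"])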
param.grad = None
callback_fn = functools.partial(self._post_reduction_hook, param)
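# ``chunk_and_pad`` splits the grad into one equal-sized chunk per rank of
# the reduce-scatter process group, padding the last chunk if needed so that
# reduce-scatter sees equal-sized inputs.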
grad_chunks = chunk_and_pad(grad, self.process_group_reduce_scatter.size())
self._reducer.reduce_scatter_async(
grad_chunks, group=self.process_group_reduce_scatter, callback_fn=callback_fn
)
else:
# Currently the only way for _is_sharded to be False is if
# world_size == 1. This could be relaxed in the future, in which
# case grads should be all-reduced here.
assert self.world_size == 1
self._post_reduction_hook(param, param.grad.data)
# After _post_backward_hook returns, orig_grad_data will eventually
# go out of scope, at which point it could otherwise be freed for
# further reuse by the main stream while the div/reduce_scatter/copy
# are underway in the post_backward stream. See:
# github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
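# ``Tensor.record_stream`` informs the caching allocator that this tensor is
# in use by the post_backward stream, so its memory is not reused until the
# work queued on that stream completes.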
orig_grad_data.record_stream(self._streams["post_backward"])