in fairscale/nn/data_parallel/fully_sharded_data_parallel.py [0:0]
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
    self._lazy_init()

    # Start of a forward pass.
    self.training_state = TrainingState.FORWARD

    # For root and mixed precision, we convert the input to FP16 (no_grad is needed for
    # the conversion).
    if self._is_root and self.mixed_precision:
        args, kwargs = cast_floats_to_right_precision(True, True, *args, **kwargs)

    # If ``force_input_to_fp32`` is enabled, convert the input to FP32 when we are in
    # full precision. no_grad is not used because the input might be for a non-root
    # instance, which means autograd needs to go through the conversion.
    if self.force_input_to_fp32 and not self.mixed_precision:
        args, kwargs = cast_floats_to_right_precision(False, False, *args, **kwargs)

    # All-gather full parameters. This will also transfer FP32 parameters to
    # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
    self._rebuild_full_params()

    # Register backward hooks to reshard params and reduce-scatter grads.
    # These need to be re-registered every forward pass.
    self._register_post_backward_hooks()

    outputs = self.module(*args, **kwargs)

    if self.reshard_after_forward:
        self._free_full_params()
        if self.mixed_precision or self.move_params_to_cpu:
            self._free_fp16_param_shard()

    # Switch to main FP32 param shard. We maintain this invariant throughout
    # the code, i.e., ``p.data == p._fp32_shard`` after each function. This
    # also ensures that after the first forward, the optimizer state will be
    # initialized with the correct dtype and (sharded) size, since optimizer
    # state is typically initialized lazily in ``optim.step()``.
    self._use_fp32_param_shard()

    # Register pre-backward hooks to all-gather the params for the backward
    # pass (if the output's grad is needed). This won't register anything if
    # we are in eval mode.
    #
    # Some models run the forward pass multiple times; we need to register the
    # pre-backward hook on every output since the last output's hook has to
    # fire first to set up for backward. However, we use ``self._pre_backward_hook_has_run``
    # to prevent repeated overhead from multiple hook callbacks.
    outputs = self._register_pre_backward_hooks(outputs)

    # Done with a forward pass.
    self.training_state = TrainingState.IDLE

    # Only need to clear the cache during forward. During backward, the cache is not used.
    # TODO (Min): Future PyTorch versions may provide a way to completely disable this
    # cache. Update this when that's available.
    if self.clear_autocast_cache:
        torch.clear_autocast_cache()

    self._free_ssd_offload()
    return outputs
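
# --- Usage sketch (illustrative, not part of the file above) ---
# A minimal sketch of how this ``forward`` is reached in practice: wrap a module in
# FullyShardedDataParallel and call it like any ``nn.Module``. Assumptions: a
# torch.distributed process group has already been initialized (e.g., via ``torchrun``)
# and a CUDA device is available; ``build_and_run`` and the layer sizes are made up for
# this example. The constructor flags shown are the ones the forward pass above consults
# (``mixed_precision``, ``reshard_after_forward``, ``clear_autocast_cache``).
import torch
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


def build_and_run(local_rank: int) -> torch.Tensor:
    torch.cuda.set_device(local_rank)
    model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).cuda()
    fsdp_model = FSDP(
        model,
        mixed_precision=True,        # the root instance casts inputs to FP16 in forward()
        reshard_after_forward=True,  # free the full params right after forward()
        clear_autocast_cache=False,
    )
    x = torch.randn(8, 1024, device="cuda")
    # Calling the wrapper runs the FullyShardedDataParallel.forward shown above.
    return fsdp_model(x)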