in apex/apex/amp/handle.py [0:0]
@contextlib.contextmanager
def scale_loss(loss,
optimizers,
loss_id=0,
model=None,
delay_unscale=False,
delay_overflow_check=False):
"""
On context manager entrance, creates ``scaled_loss = loss.float() * current_loss_scale``
(the loss scale is managed internally by Amp).  ``scaled_loss`` is yielded so that the
user can call ``scaled_loss.backward()``::

    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

On context manager exit (if ``delay_unscale=False``), the gradients are checked for infs/NaNs
and unscaled, so that ``optimizer.step()`` can be called.
.. note::
If Amp is using explicit FP32 master params (which is the default for ``opt_level=O2``, and
can also be manually enabled by supplying ``master_weights=True`` to ``amp.initialize``),
any FP16 gradients are copied to FP32 master gradients before being unscaled.
``optimizer.step()`` will then apply the unscaled master gradients to the master params.
.. warning::
If Amp is using explicit FP32 master params, only the FP32 master gradients will be
unscaled. The direct ``.grad`` attributes of any FP16
model params will remain scaled after context manager exit.
This subtlety affects gradient clipping. See "Gradient clipping" under
`Advanced Amp Usage`_ for best practices.
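As an illustrative sketch of that best practice (``max_norm`` is a placeholder value),
clipping is applied to the unscaled master gradients via ``amp.master_params`` rather
than to the FP16 ``.grad`` attributes::

    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    # Master gradients have been unscaled on context manager exit.
    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm)
    optimizer.step()
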
Args:
loss(Tensor): Typically a scalar Tensor. The ``scaled_loss`` that the context
manager yields is simply ``loss.float()*loss_scale``, so in principle
``loss`` could have more than one element, as long as you call
``backward()`` on ``scaled_loss`` appropriately within the context manager body.
optimizers: All optimizer(s) for which the current backward pass is creating gradients.
Must be an optimizer or list of optimizers returned from an earlier call
to ``amp.initialize``. For example usage with multiple optimizers, see
"Multiple models/optimizers/losses" under `Advanced Amp Usage`_.
loss_id(int, optional, default=0): When used in conjunction with the ``num_losses`` argument
to ``amp.initialize``, enables Amp to use a different loss scale per loss. ``loss_id``
must be an integer between 0 and ``num_losses - 1`` that tells Amp which loss is
being used for the current backward pass. See "Multiple models/optimizers/losses"
under `Advanced Amp Usage`_ for examples, and the sketch near the end of this
docstring. If ``loss_id`` is left unspecified, Amp will use the default global
loss scaler for this backward pass.
model(torch.nn.Module, optional, default=None): Currently unused, reserved to enable future
optimizations.
delay_unscale(bool, optional, default=False): ``delay_unscale`` is never necessary, and
the default value of ``False`` is strongly recommended.
If ``True``, Amp will not unscale the gradients or perform model->master
gradient copies on context manager exit.
``delay_unscale=True`` is a minor performance optimization that can introduce
subtle pitfalls (especially with multiple models/optimizers/losses),
so only use it if you know what you're doing.
"Gradient accumulation across iterations" under `Advanced Amp Usage`_
illustrates a situation where this CAN (but does not need to) be used; a sketch
also appears near the end of this docstring.
.. warning::
If ``delay_unscale`` is ``True`` for a given backward pass, ``optimizer.step()`` cannot be
called after that context manager exits; it must wait for a later backward pass whose
``scale_loss`` invocation leaves ``delay_unscale`` at its default of ``False``.
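
The following is an illustrative sketch (model, optimizer, and loss names are placeholders)
of per-loss scaling with ``num_losses`` and ``loss_id``, using one optimizer per loss::

    [model0, model1], [optimizer0, optimizer1] = amp.initialize(
        [model0, model1], [optimizer0, optimizer1], opt_level="O1", num_losses=2)

    with amp.scale_loss(loss0, optimizer0, loss_id=0) as scaled_loss:
        scaled_loss.backward()
    optimizer0.step()

    with amp.scale_loss(loss1, optimizer1, loss_id=1) as scaled_loss:
        scaled_loss.backward()
    optimizer1.step()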
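
A sketch of gradient accumulation with ``delay_unscale`` (``iteration`` and
``iters_to_accumulate`` are placeholders; unscaling and the overflow check happen
only on the iteration that actually steps)::

    if iteration % iters_to_accumulate == 0:
        # Unscale, check for infs/NaNs, and step on this iteration.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    else:
        # Accumulate scaled gradients; don't unscale or step yet.
        with amp.scale_loss(loss, optimizer, delay_unscale=True) as scaled_loss:
            scaled_loss.backward()
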
.. _`Advanced Amp Usage`:
https://nvidia.github.io/apex/advanced.html
"""
if not hasattr(_amp_state, "opt_properties"):
raise RuntimeError("Invoked 'with amp.scale_loss`, but internal Amp state has not been initialized. "
"model, optimizer = amp.initialize(model, optimizer, opt_level=...) must be called "
"before `with amp.scale_loss`.")
if not _amp_state.opt_properties.enabled:
yield loss
return
if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC):
optimizers = [optimizers]
# This is what happens when I have to support tools from different sources under the same API...
# TODO: Rewrite FusedAdam to use multi-tensor apply and the same loss scaler.
if isinstance(optimizers, FP16_Optimizer_for_fused):
loss_scale = optimizers.cur_scale
else:
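# When num_losses > 1 was passed to amp.initialize, each loss_id indexes its own
# LossScaler, so the loss scale for each loss can evolve independently.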
loss_scaler = _amp_state.loss_scalers[loss_id]
loss_scale = loss_scaler.loss_scale()
if ((not _amp_state.opt_properties.master_weights)
and (not loss_scaler.dynamic)
and loss_scale == 1.0):
yield loss.float()
# Needing to drop the cache here as well is an ugly gotcha.
# But for now I think it's necessary to short-circuit.
# Probably ok to skip this if not delay_unscale
if _amp_state.opt_properties.patch_torch_functions:
_amp_state.handle._clear_cache()
return
if not delay_unscale:
if isinstance(optimizers, list):
for optimizer in optimizers:
if not optimizer._amp_stash.params_have_scaled_gradients:
optimizer._prepare_amp_backward()
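# Hand the scaled loss to the body of the user's `with` block; the user calls
# .backward() on it there, producing scaled gradients.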
yield (loss.float())*loss_scale
if delay_unscale:
for optimizer in optimizers:
optimizer._amp_stash.params_have_scaled_gradients = True
else:
# FusedAdam and FusedSGD will take care of unscaling as part of their step() methods.
if not isinstance(optimizers, FP16_Optimizer_for_fused):
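# Unscale this loss's gradients (copying FP16 model grads into FP32 master grads
# first when master weights are in use) and record any infs/NaNs in the scaler's
# overflow state.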
loss_scaler.clear_overflow_state()
for optimizer in optimizers:
optimizer._post_amp_backward(loss_scaler)
optimizer._amp_stash.params_have_scaled_gradients = False
# For future fused optimizers that enable sync-free dynamic loss scaling,
# should_skip will always be False.
should_skip = False if delay_overflow_check else loss_scaler.update_scale()
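# If an overflow was detected, update_scale() has already reduced the loss scale;
# additionally patch each optimizer's step() with a one-shot no-op so this
# iteration's update is skipped. skip_step clears stale FP32 master grads and
# restores the real step() when it runs.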
if should_skip:
for optimizer in optimizers:
if not optimizer._amp_stash.already_patched:
# Close on loss_scaler and loss_id as well, to be safe. Probably not
# necessary because amp.scale_loss is already creating a temporary scope.
def patch_step(opt, loss_scaler, loss_id):
opt_step = opt.step
def skip_step(closure=None):
if closure is not None:
raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
maybe_print(("Gradient overflow. Skipping step, loss scaler " +
"{} reducing loss scale to {}").format(loss_id,
loss_scaler.loss_scale()))
if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
# Clear the master grads that wouldn't be zeroed by model.zero_grad()
for param in opt._amp_stash.all_fp32_from_fp16_params:
param.grad = None
opt.step = opt_step
opt._amp_stash.already_patched = False
return skip_step
optimizer.step = patch_step(optimizer, loss_scaler, loss_id)
optimizer._amp_stash.already_patched = True
# Probably ok to skip this if not delay_unscale
if _amp_state.opt_properties.patch_torch_functions:
_amp_state.handle._clear_cache()