def scale_loss()

in apex/apex/amp/handle.py [0:0]


def scale_loss(loss,
               optimizers,
               loss_id=0,
               model=None,
               delay_unscale=False,
               delay_overflow_check=False):
    """
    On context manager entrance, creates ``scaled_loss = (loss.float())*current loss scale``.
    ``scaled_loss`` is yielded so that the user can call ``scaled_loss.backward()``::

        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

    On context manager exit (if ``delay_unscale=False``), the gradients are checked for infs/NaNs
    and unscaled, so that ``optimizer.step()`` can be called.

    .. note::
        If Amp is using explicit FP32 master params (which is the default for ``opt_level=O2``, and
        can also be manually enabled by supplying ``master_weights=True`` to ``amp.initialize``)
        any FP16 gradients are copied to FP32 master gradients before being unscaled.
        ``optimizer.step()`` will then apply the unscaled master gradients to the master params.

    .. warning::
        If Amp is using explicit FP32 master params, only the FP32 master gradients will be
        unscaled.  The direct ``.grad`` attributes of any FP16
        model params will remain scaled after context manager exit.
        This subtlety affects gradient clipping.  See "Gradient clipping" under
        `Advanced Amp Usage`_ for best practices.

    Args:
        loss(Tensor):  Typically a scalar Tensor. The ``scaled_loss`` that the context
            manager yields is simply ``loss.float()*loss_scale``, so in principle
            ``loss`` could have more than one element, as long as you call
            ``backward()`` on ``scaled_loss`` appropriately within the context manager body.
        optimizers:  All optimizer(s) for which the current backward pass is creating gradients.
            Must be an optimizer or list of optimizers returned from an earlier call
            to ``amp.initialize``.  For example use with multiple optimizers, see
            "Multiple models/optimizers/losses" under `Advanced Amp Usage`_.
        loss_id(int, optional, default=0):  When used in conjunction with the ``num_losses`` argument
            to ``amp.initialize``, enables Amp to use a different loss scale per loss.  ``loss_id``
            must be an integer between 0 and ``num_losses`` that tells Amp which loss is
            being used for the current backward pass.  See "Multiple models/optimizers/losses"
            under `Advanced Amp Usage`_ for examples.  If ``loss_id`` is left unspecified, Amp
            will use the default global loss scaler for this backward pass.
        model(torch.nn.Module, optional, default=None):  Currently unused, reserved to enable future
            optimizations.
        delay_unscale(bool, optional, default=False):  ``delay_unscale`` is never necessary, and
            the default value of ``False`` is strongly recommended.
            If ``True``, Amp will not unscale the gradients or perform model->master
            gradient copies on context manager exit.
            ``delay_unscale=True`` is a minor ninja performance optimization and can result
            in weird gotchas (especially with multiple models/optimizers/losses),
            so only use it if you know what you're doing.
            "Gradient accumulation across iterations" under `Advanced Amp Usage`_
            illustrates a situation where this CAN (but does not need to) be used.
        delay_overflow_check(bool, optional, default=False):  Only consulted when
            ``delay_unscale=False``.  If ``True``, the loss scaler's ``update_scale()`` is not
            called on context manager exit, so this invocation never patches
            ``optimizer.step()`` to skip the step.  Gradients are still unscaled.

    .. warning::
        If ``delay_unscale`` is ``True`` for a given backward pass, ``optimizer.step()`` cannot be
        called yet after context manager exit, and must wait for another, later backward context
        manager invocation with ``delay_unscale`` left to False.

    Raises:
        RuntimeError: If ``amp.initialize`` has not populated the internal Amp state yet.

    .. _`Advanced Amp Usage`:
        https://nvidia.github.io/apex/advanced.html
    """
    if not hasattr(_amp_state, "opt_properties"):
        raise RuntimeError("Invoked 'with amp.scale_loss`, but internal Amp state has not been initialized.  "
                           "model, optimizer = amp.initialize(model, optimizer, opt_level=...) must be called "
                           "before `with amp.scale_loss`.")

    # Amp disabled: hand the loss back untouched and do nothing on exit.
    if not _amp_state.opt_properties.enabled:
        yield loss
        return

    # Normalize a single optimizer into a one-element list so the loops below are uniform.
    if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC):
        optimizers = [optimizers]

    # this is what happens when i have to support tools from different sources under the same API...
    # TODO:  Rewrite FusedAdam to use multi-tensor apply and the same loss scaler.
    if isinstance(optimizers, FP16_Optimizer_for_fused):
        # The fused wrapper supplies its own scale; no scaler is looked up in
        # _amp_state.loss_scalers for it.  Bind the name explicitly so later code can
        # test for it instead of hitting an unbound local.
        loss_scaler = None
        loss_scale = optimizers.cur_scale
    else:
        loss_scaler = _amp_state.loss_scalers[loss_id]
        loss_scale = loss_scaler.loss_scale()

    # Short-circuit: with no master weights, a static scaler and a scale of exactly 1.0 there is
    # nothing to scale or unscale.  The ``loss_scaler is not None`` term excludes the fused
    # wrapper, for which this test previously raised UnboundLocalError on ``loss_scaler``.
    if (loss_scaler is not None
        and (not _amp_state.opt_properties.master_weights)
        and (not loss_scaler.dynamic)
        and loss_scale == 1.0):
        yield loss.float()
        # Needing to drop the cache here as well is an ugly gotcha.
        # But for now I think it's necessary to short-circuit.
        # Probably ok to skip this if not delay_unscale
        if _amp_state.opt_properties.patch_torch_functions:
            _amp_state.handle._clear_cache()
        return

    if not delay_unscale:
        if isinstance(optimizers, list):
            for optimizer in optimizers:
                # Skip preparation when a previous delay_unscale=True pass left scaled
                # gradients in place for this optimizer.
                if not optimizer._amp_stash.params_have_scaled_gradients:
                    optimizer._prepare_amp_backward()

    # Control returns to the caller's ``with`` body here; everything below runs on exit.
    yield (loss.float())*loss_scale

    if delay_unscale:
        # NOTE(review): unlike the pre-backward block above, this loop has no list check, so it
        # iterates ``optimizers`` directly even for the fused wrapper -- confirm intended.
        for optimizer in optimizers:
            optimizer._amp_stash.params_have_scaled_gradients = True
    else:
        # FusedAdam and FusedSGD will take care of unscaling as part of their step() methods.
        if not isinstance(optimizers, FP16_Optimizer_for_fused):
            loss_scaler.clear_overflow_state()
            for optimizer in optimizers:
                optimizer._post_amp_backward(loss_scaler)
                optimizer._amp_stash.params_have_scaled_gradients = False
            # For future fused optimizers that enable sync-free dynamic loss scaling,
            # should_skip will always be False.
            should_skip = False if delay_overflow_check else loss_scaler.update_scale()
            if should_skip:
                for optimizer in optimizers:
                    # Patch each optimizer at most once; skip_step clears the flag when it runs.
                    if not optimizer._amp_stash.already_patched:
                        # Close on loss_scaler and loss_id as well, to be safe.  Probably not
                        # necessary because amp.scale_loss is already creating a temporary scope.
                        def patch_step(opt, loss_scaler, loss_id):
                            opt_step = opt.step
                            def skip_step(closure=None):
                                if closure is not None:
                                    raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
                                maybe_print(("Gradient overflow.  Skipping step, loss scaler " +
                                             "{} reducing loss scale to {}").format(loss_id,
                                             loss_scaler.loss_scale()))
                                if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
                                    # Clear the master grads that wouldn't be zeroed by model.zero_grad()
                                    for param in opt._amp_stash.all_fp32_from_fp16_params:
                                        param.grad = None
                                # One-shot: restore the real step() so the next call steps normally.
                                opt.step = opt_step
                                opt._amp_stash.already_patched = False
                            return skip_step
                        optimizer.step = patch_step(optimizer, loss_scaler, loss_id)
                        optimizer._amp_stash.already_patched = True

    # Probably ok to skip this if not delay_unscale
    if _amp_state.opt_properties.patch_torch_functions:
        _amp_state.handle._clear_cache()