def scale()

in fairscale/optim/grad_scaler.py


    def scale(self, outputs: Union[torch.Tensor, List[torch.Tensor]]) -> Union[torch.Tensor, abc.Iterable]:
        """
        Multiplies ('scales') a tensor or an iterable of tensors by the scale factor.

        Returns scaled outputs.  If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.

        Args:
            outputs (Tensor or iterable of Tensors):  Outputs to scale.
        """
        if not self._enabled:
            return outputs

        # Short-circuit for the common case.
        if isinstance(outputs, torch.Tensor):
            assert outputs.device.type in ("cuda", "xla", "cpu")
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)  # type: ignore
            assert self._scale is not None
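            # Move the scale onto the output's device; .to() is a no-op when it
            # already lives there, and non_blocking=True lets the copy overlap
            # with other work otherwise.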
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # Invoke the more complex machinery only when handling multiple outputs.
        stash: List[_GeneralMultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale

        def apply_scale(val: Union[torch.Tensor, abc.Iterable]) -> Union[torch.Tensor, abc.Iterable]:
            if isinstance(val, torch.Tensor):
                assert val.device.type in ("cuda", "xla", "cpu")
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)  # type: ignore
                    assert self._scale is not None
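                    # _GeneralMultiDeviceReplicator lazily caches one copy of
                    # the scale per device, so tensors living on several
                    # devices are scaled without redundant transfers.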
                    stash.append(_GeneralMultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, abc.Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)
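
For context, here is a minimal usage sketch. It assumes the enclosing class is fairscale's ShardedGradScaler (defined in this same file) constructed with its default arguments; the tensors below are placeholders. Note that when the scaler is disabled (for example because CUDA is unavailable), scale() returns its inputs unmodified, per the guard at the top of the method.

    import torch

    from fairscale.optim.grad_scaler import ShardedGradScaler

    scaler = ShardedGradScaler()  # assumption: default constructor arguments

    loss = torch.tensor(2.0, requires_grad=True)

    # A single tensor takes the fast path: one multiply by the scale factor.
    scaled_loss = scaler.scale(loss)

    # An iterable of tensors takes the recursive apply_scale path; list and
    # tuple containers are rebuilt with their original type.
    scaled_a, scaled_b = scaler.scale((loss, loss * 3.0))

In a full training loop this is typically paired with the rest of the scaler API, mirroring torch.cuda.amp.GradScaler: scaler.scale(loss).backward(), then scaler.step(optimizer), then scaler.update().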