in fairscale/optim/grad_scaler.py [0:0]
def scale(self, outputs: Union[torch.Tensor, List[torch.Tensor]]) -> Union[torch.Tensor, abc.Iterable]:
    """
    Multiplies ('scales') a tensor or list of tensors by the scale factor.

    Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
    unmodified.

    Args:
        outputs (Tensor or iterable of Tensors): Outputs to scale.
    """
    if not self._enabled:
        return outputs

    # Short-circuit for the common case.
    if isinstance(outputs, torch.Tensor):
        assert outputs.is_cuda or outputs.device.type == "xla" or outputs.device.type == "cpu"
        if self._scale is None:
            self._lazy_init_scale_growth_tracker(outputs.device)  # type: ignore
        assert self._scale is not None
        return outputs * self._scale.to(device=outputs.device, non_blocking=True)

    # Invoke the more complex machinery only if we're treating multiple outputs.
    stash: List[_GeneralMultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale

    def apply_scale(val: Union[torch.Tensor, abc.Iterable]) -> Union[torch.Tensor, abc.Iterable]:
        if isinstance(val, torch.Tensor):
            assert val.is_cuda or val.device.type == "xla" or val.device.type == "cpu"
            if len(stash) == 0:
                if self._scale is None:
                    self._lazy_init_scale_growth_tracker(val.device)  # type: ignore
                assert self._scale is not None
                stash.append(_GeneralMultiDeviceReplicator(self._scale))
            return val * stash[0].get(val.device)
        elif isinstance(val, abc.Iterable):
            iterable = map(apply_scale, val)
            if isinstance(val, list) or isinstance(val, tuple):
                return type(val)(iterable)
            else:
                return iterable
        else:
            raise ValueError("outputs must be a Tensor or an iterable of Tensors")

    return apply_scale(outputs)
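
# A minimal usage sketch. Assumptions: the scaler defined in this module follows the
# standard torch.cuda.amp.GradScaler interface (scale / step / update); the sketch
# below uses the upstream torch.cuda.amp.GradScaler to illustrate the calling
# convention, and `model`, `optimizer`, and the synthetic tensors are hypothetical
# placeholders rather than names from this file.
#
# import torch
# import torch.nn.functional as F
#
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = torch.nn.Linear(8, 2).to(device)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
#
# data = torch.randn(4, 8, device=device)
# target = torch.randn(4, 2, device=device)
#
# with torch.cuda.amp.autocast(enabled=(device == "cuda")):
#     loss = F.mse_loss(model(data), target)
#
# # Single-tensor path: `scale` multiplies the loss by the current scale factor, so
# # backward() produces scaled gradients; `step` unscales them and skips the update
# # if infs/NaNs are found, and `update` adjusts the scale factor.
# scaler.scale(loss).backward()
# scaler.step(optimizer)
# scaler.update()
#
# # Iterable path: a list or tuple of tensors is scaled element-wise and the
# # container type is preserved, matching the `apply_scale` branch above.
# loss_a = torch.tensor(1.0, device=device)
# loss_b = torch.tensor(2.0, device=device)
# scaled_a, scaled_b = scaler.scale((loss_a, loss_b))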