in optimum/quanto/calibrate.py
def calibrate_input(self, module: torch.nn.Module, input, momentum: float = 0.9):
    """Calibrate a module input scale

    This is registered as a global hook that is called before any module forward pre hook.
    """
    if isinstance(module, QModuleMixin) and module.activation_qtype is not None:
        input = input[0]
        if isinstance(input, ActivationQBytesTensor):
            # Just adopt the maximum scale of the input
            module.input_scale = torch.max(input._scale)
        else:
            # Evaluate the best scale
            input_scale = absmax_scale(input, module.activation_qtype)
            module.input_scale = _updated_scale(module.input_scale, input_scale, momentum)
        if self.streamline and module not in self.streamline_hooks:
            # Add a hook to tag the module outputs (after the module quantization hook in QModuleMixin)
            self.streamline_hooks[module] = module.register_forward_hook(self.tag_outputs)
        return input
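For context, the docstring above describes this as a global forward pre-hook: while a calibration pass is active it runs before every module's forward and keeps refining module.input_scale from the activations it observes. A minimal usage sketch, assuming the standard optimum-quanto calibration workflow (model and calibration_loader are placeholders):

    from optimum.quanto import Calibration, freeze, qint8, quantize

    # Replace supported layers with quantized modules (QModuleMixin subclasses).
    quantize(model, weights=qint8, activations=qint8)

    # While the Calibration context is active, input scales are updated on
    # every forward pass using the momentum rule shown above.
    with Calibration(momentum=0.9):
        for batch in calibration_loader:
            model(batch)

    # Freeze the calibrated scales and quantized weights.
    freeze(model)

The momentum argument controls how quickly the running scale tracks new batches. A plausible reading of the _updated_scale call (a sketch, not the library's exact implementation) is an exponential moving average of the per-batch absmax scales:

    def _updated_scale_sketch(old_scale, new_scale, momentum):
        # Blend the previous estimate with the scale measured on the current batch.
        return momentum * old_scale + (1.0 - momentum) * new_scale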