optimum/quanto/nn/qmodule.py (217 lines of code) (raw):

# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC from typing import Optional, Union import torch from ..tensor import ( AbsmaxOptimizer, ActivationQBytesTensor, MaxOptimizer, Optimizer, QTensor, SymmetricOptimizer, WeightQBitsTensor, WeightQBytesTensor, qint2, qint4, qtype, qtypes, quantize_activation, quantize_weight, ) __all__ = ["QModuleMixin", "register_qmodule", "quantize_module"] _QMODULE_TABLE = {} def register_qmodule(module_cls): """ Used for registering a new quantized module. The QModule must implement two abstract methods: - qcreate: class method to instantiate a new QModule from an nn.Module, without copying its weights, - forward: instance method for quantized inference. The code to register a new module looks like: ``` @register_qmodule(<base torch.nn.Module>) class MyQModule(QModuleMixin, <base torch.nn.Module>): <implementation> @classmethod def qcreate(cls, module: torch.nn.Module, weights: Optional[qtype], activations: Optional[qtype] = None, optimizer: Optional[Optimizer] = None): ... def forward(self, input: torch.Tensor) -> torch.Tensor: ... ``` """ def wrapper(cls): _QMODULE_TABLE[module_cls] = cls return cls return wrapper def quantize_module( module, weights: Optional[Union[qtype, str]] = None, activations: Optional[Union[qtype, str]] = None, optimizer: Optional[Optimizer] = None, ): for cls in _QMODULE_TABLE: if isinstance(module, cls): qcls = _QMODULE_TABLE[cls] return qcls.from_module(module, weights=weights, activations=activations, optimizer=optimizer) return None class QModuleMixin(ABC): def __init__( self, *args, weights: Optional[Union[qtype, str]] = None, activations: Optional[Union[qtype, str]] = None, optimizer: Optional[Optimizer] = None, quantize_input: Optional[bool] = False, device: Optional[torch.device] = None, **kwargs, ): # The tests below are meant to help people writing their own quantized Module class mro = self.__class__.__mro__ if torch.nn.Module not in mro: raise TypeError("Quantized modules must inherit from a torch.nn.Module class") if mro.index(__class__) > mro.index(torch.nn.Module): raise TypeError( "QModuleMixin must be placed before any torch.nn.Module class in quantized module inheritance." ) # This will setup the torch.nn.Module super().__init__(*args, device=device, **kwargs) if weights is not None and not isinstance(weights, qtype): weights = qtypes[weights] if activations is not None and not isinstance(activations, qtype): activations = qtypes[activations] self.weight_qtype = weights self.weight_group_size = None if self.weight_qtype in (qint2, qint4): out_features = self.weight.shape[0] in_features = self.weight.numel() // out_features group_size = 128 if in_features > group_size: while in_features % group_size != 0 and group_size > 32: group_size -= 32 if in_features % group_size == 0: self.weight_group_size = group_size self.activation_qtype = activations self._quantize_hooks = {} if activations is not None: if quantize_input: self._quantize_hooks["input"] = self.register_forward_pre_hook(self.quantize_input) self._quantize_hooks["output"] = self.register_forward_hook(self.quantize_output) if optimizer is None and self.weight_qtype is not None: optimizer = AbsmaxOptimizer() if self.weight_qtype.bits == 8 else MaxOptimizer() self.optimizer = optimizer scale_dtype = torch.float32 if self.weight is None else self.weight.dtype self.register_buffer("input_scale", torch.ones((), dtype=scale_dtype, device=device)) self.register_buffer("output_scale", torch.ones((), dtype=scale_dtype, device=device)) def disable_output_quantization(self): if "output" in self._quantize_hooks: self._quantize_hooks["output"].remove() def _save_to_state_dict(self, destination, prefix, keep_vars): if self.weight_qtype is None or not self.frozen: # Save standard weight Tensor destination[prefix + "weight"] = ( self.weight if (self.weight is None or keep_vars) else self.weight.detach() ) else: # Save QTensor using dedicated method self.weight.save_to_state_dict(destination, prefix + "weight.", keep_vars) if self.bias is not None: destination[prefix + "bias"] = self.bias if keep_vars else self.bias.detach() destination[prefix + "input_scale"] = self.input_scale if keep_vars else self.input_scale.detach() destination[prefix + "output_scale"] = self.output_scale if keep_vars else self.output_scale.detach() def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): weight_name = prefix + "weight" if self.weight_qtype is not None and weight_name not in state_dict: # The weight Tensor is not present because it is a flattened QTensor weight_prefix = weight_name + "." # note: deserialized_weight can be None if a key is missing in the state_dict if self.weight_qtype.bits == 8: deserialized_weight = WeightQBytesTensor.load_from_state_dict( state_dict, weight_prefix, qtype=self.weight_qtype, axis=0, size=self.weight.size(), stride=self.weight.stride(), activation_qtype=self.activation_qtype, missing_keys=missing_keys, ) else: deserialized_weight = WeightQBitsTensor.load_from_state_dict( state_dict, weight_prefix, qtype=self.weight_qtype, axis=0, group_size=self.weight_group_size, size=self.weight.size(), stride=self.weight.stride(), missing_keys=missing_keys, ) if deserialized_weight is not None: deserialized_weight = deserialized_weight.optimize() assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) if assign_to_params_buffers and (deserialized_weight is not None): self.weight = torch.nn.Parameter(deserialized_weight) elif deserialized_weight is not None: if type(self.weight.data) is not type(deserialized_weight): # Reloading frozen weights into unfrozen module: move to the correct device and force assignment self.weight = torch.nn.Parameter(deserialized_weight.to(self.weight.device)) else: # FIXME: here we should copy frozen weights into frozen module, but this leads to grad error self.weight = torch.nn.Parameter(deserialized_weight.to(self.weight.device)) super()._load_from_state_dict( state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs ) @classmethod def from_module( cls, module: torch.nn.Module, weights: Optional[qtype] = None, activations: Optional[qtype] = None, optimizer: Optional[Optimizer] = None, ): # Create the quantized module on the meta device to prevent weights intialization qmodule = cls.qcreate(module, weights, activations, optimizer, device="meta") if qmodule is None: return None # Move the quantized module to the target device, but with empty weights device = torch.device("cpu") if module.weight is None else module.weight.device qmodule = qmodule.to_empty(device=device) # Set scales that were initialized to empty values qmodule.input_scale = torch.ones_like(qmodule.input_scale) qmodule.output_scale = torch.ones_like(qmodule.output_scale) with torch.no_grad(): qmodule.weight = module.weight if module.bias is not None: qmodule.bias = module.bias return qmodule.to(device) @classmethod def qcreate( cls, module: torch.nn.Module, weights: Optional[qtype], activations: Optional[qtype] = None, optimizer: Optional[Optimizer] = None, device: Optional[torch.device] = None, ): raise NotImplementedError @property def qweight(self): """Return the module quantized weight When the module is frozen or does not quantize its weight parameter, it simply returns the weight. When the module is not frozen, this property is required to add the dynamic quantization of the weight parameter to the graph and allow gradients to be propagated to the underlying weight float values. """ if self.weight_qtype is None: # QModule that does not quantize its weights return None if isinstance(self.weight, QTensor): # Frozen QModule return self.weight # Quantize dynamically the weights per-axis if isinstance(self.optimizer, SymmetricOptimizer): scale = self.optimizer(self.weight, qtype=self.weight_qtype, axis=0) shift = None else: scale, shift = self.optimizer( self.weight, qtype=self.weight_qtype, axis=0, group_size=self.weight_group_size ) return quantize_weight( self.weight, qtype=self.weight_qtype, axis=0, scale=scale, shift=shift, group_size=self.weight_group_size, activation_qtype=self.activation_qtype, ) def qforward(self, input: torch.Tensor) -> torch.Tensor: raise NotImplementedError def quantize_input(self, module: torch.nn.Module, input: torch.Tensor) -> torch.Tensor: input = input[0] if isinstance(input, ActivationQBytesTensor): if input.qtype != self.activation_qtype: raise ValueError( "Models with heterogeneous quantized activations are not supported:" f" expected {self.activation_qtype.name} input but got {input.qtype.name} instead." ) else: input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale) return input def quantize_output( self, module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor, ) -> torch.Tensor: return quantize_activation(output, qtype=self.activation_qtype, scale=self.output_scale) def freeze(self): qweight = self.qweight if qweight is not None: # Replace float weights by quantized weights self.weight = torch.nn.Parameter(qweight) @property def frozen(self): return isinstance(self.weight, QTensor)