optimum/quanto/tensor/weights/qbytes.py (226 lines of code) (raw):

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
from typing import Optional

import torch
from torch.autograd import Function

from ...library import is_extension_available
from ..function import QuantizedLinearFunction
from ..qbytes import QBytesTensor
from ..qtensor import qfallback
from ..qtype import qtype, qtypes


__all__ = ["WeightQBytesTensor"]


class WeightQBytesQuantizer(Function):
    """Autograd function that quantizes a tensor into a `WeightQBytesTensor`."""

    @staticmethod
    def forward(
        ctx, base: torch.Tensor, qtype: qtype, axis: int, scale: torch.Tensor, activation_qtype: qtype, optimized: bool
    ) -> torch.Tensor:
        """Quantize `base` symmetrically and wrap the result in a weights tensor.

        Args:
            ctx: The autograd context (unused here).
            base (`torch.Tensor`): The tensor to quantize.
            qtype (`qtype`): The target quantization type; it must be 8-bit.
            axis (`int`): The quantization axis forwarded to the quantization op.
            scale (`torch.Tensor`): The quantization scale.
            activation_qtype (`qtype`): The qtype used for the activations (can be `None`).
            optimized (`bool`): If True, go through `WeightQBytesTensor.create`, which may
                select a specialized subclass; otherwise instantiate a plain `WeightQBytesTensor`.

        Returns:
            The quantized tensor.

        Raises:
            ValueError: If `qtype` is not an 8-bit qtype.
        """
        if qtype.bits != 8:
            raise ValueError("QBytesTensor can only be of 8-bit qtype")
        data = torch.ops.quanto.quantize_symmetric(base, dtype=qtype.dtype, axis=axis, scale=scale)
        # The instantiation of the quantized tensor must happen within the context of the Function
        # for the autograd magic to work.
        if optimized:
            return WeightQBytesTensor.create(
                qtype,
                axis,
                size=base.size(),
                stride=base.stride(),
                data=data,
                scale=scale,
                activation_qtype=activation_qtype,
            )
        return WeightQBytesTensor(
            qtype,
            axis,
            size=base.size(),
            stride=base.stride(),
            data=data,
            scale=scale,
            activation_qtype=activation_qtype,
        )

    @staticmethod
    def backward(ctx, gO):
        """Pass the incoming gradient straight through to `base`.

        One entry is returned per `forward` input (`base`, `qtype`, `axis`, `scale`,
        `activation_qtype`, `optimized`): only `base` receives a gradient.
        """
        # For autograd, quantization is a no-op
        return gO, None, None, None, None, None


class WeightQBytesLinearFunction(QuantizedLinearFunction):
    """Linear function using the `qbytes_mm` op on the quantized weights."""

    @staticmethod
    def forward(ctx, input, other, bias=None):
        """Compute the linear projection of `input` by the quantized weights `other`.

        Args:
            ctx: The autograd context; `input` and `other` are saved on it for the backward pass.
            input: The activations, either a `QBytesTensor` or a regular tensor.
            other: The quantized weights (their `_data` and `_scale` are read).
            bias: An optional bias added to the output.

        Returns:
            The result of the matrix multiplication, plus `bias` if it is not `None`.
        """
        ctx.save_for_backward(input, other)
        if isinstance(input, QBytesTensor):
            # Both operands are quantized: combine the scales and multiply the raw data
            output = torch.ops.quanto.qbytes_mm(input._data, other._data, input._scale * other._scale)
        else:
            in_features = input.shape[-1]
            out_features = other.shape[0]
            output_shape = input.shape[:-1] + (out_features,)
            # Flatten the leading dimensions before calling the op, then restore them
            output = torch.ops.quanto.qbytes_mm(input.reshape(-1, in_features), other._data, other._scale)
            output = output.reshape(output_shape)
        if bias is not None:
            output = output + bias
        return output


class WeightQBytesTensor(QBytesTensor):
    """A `QBytesTensor` holding quantized weights, with an optional activation qtype."""

    @staticmethod
    def create(
        qtype,
        axis,
        size,
        stride,
        data,
        scale,
        activation_qtype: Optional[qtype] = None,
        requires_grad=False,
    ):
        """Factory method to create a QBytesTensor

        This selects the most appropriate QBytesTensor based on the configuration.

        Args:
            axis (`int`):
                The axis that is preserved by quantization (usually zero for linear weights).
            size ():
                The Tensor size.
            stride():
                The Tensor stride.
            data (`torch.Tensor`):
                The tensor data, either as a raw uint8 torch.Tensor or as a PackedTensor.
            scale (`torch.Tensor`):
                The floating point scale expressed as a torch.Tensor.
            activation_qtype (`qtype`, defaults to `None`):
                The qtype used for the activations. If one needs to use a different tensor subclass e.g. for weights depending on the activations qtype, this argument must be specified accordingly when calling `QBytesTensor.create`.
            requires_grad (`bool`):
                If the Tensor must receive a gradient or not.

        Returns:
            a `QBytesTensor` (can be a subclass).
        """
        from .marlin import MarlinF8QBytesTensor

        if (
            qtype == qtypes["qfloat8_e4m3fn"]
            and activation_qtype is None
            and scale.dtype in [torch.float16, torch.bfloat16]
            and len(size) == 2
            and (data.device.type == "cuda" and torch.version.cuda)
            and axis == 0
            and torch.cuda.get_device_capability(data.device)[0] >= 8
            and is_extension_available("quanto_cuda")
        ):
            out_features, in_features = size
            # The Marlin subclass is only selected for these feature-size multiples
            if (
                in_features >= 64
                and out_features >= 64
                and (
                    (in_features % 64 == 0 and out_features % 128 == 0)
                    or (in_features % 128 == 0 and out_features % 64 == 0)
                )
            ):
                return MarlinF8QBytesTensor(qtype, axis, size, stride, data, scale, requires_grad)

        return WeightQBytesTensor(qtype, axis, size, stride, data, scale, activation_qtype, requires_grad)

    @staticmethod
    def __new__(cls, qtype, axis, size, stride, data, scale, activation_qtype, requires_grad=False):
        # The wrapper tensor exposes the scale dtype and the data device
        assert data.device == scale.device
        return torch.Tensor._make_wrapper_subclass(
            cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad
        )

    def __init__(self, qtype, axis, size, stride, data, scale, activation_qtype, requires_grad=False):
        super().__init__(qtype, axis, size, stride, data, scale, requires_grad=requires_grad)
        self.activation_qtype = activation_qtype

    @classmethod
    def quantize(
        cls,
        base: torch.Tensor,
        qtype: qtype,
        axis: int,
        scale: torch.Tensor,
        activation_qtype: Optional[qtype] = None,
        optimized: Optional[bool] = True,
    ) -> torch.Tensor:
        """Quantize `base` through `WeightQBytesQuantizer`.

        Args:
            base (`torch.Tensor`): The tensor to quantize.
            qtype (`qtype`): The target quantization type.
            axis (`int`): The quantization axis.
            scale (`torch.Tensor`): The quantization scale.
            activation_qtype (`qtype`, defaults to `None`): The qtype used for the activations.
            optimized (`bool`, defaults to `True`): If True, the tensor is created through
                `WeightQBytesTensor.create`, which may return a specialized subclass.

        Returns:
            The quantized tensor.
        """
        return WeightQBytesQuantizer.apply(base, qtype, axis, scale, activation_qtype, optimized)

    @staticmethod
    def load_from_state_dict(state_dict, prefix, qtype, axis, size, stride, activation_qtype, missing_keys):
        """Rebuild a `WeightQBytesTensor` from the `_data` and `_scale` entries of a state dict.

        The entries found under `prefix` are popped from `state_dict`. The names of
        the entries that are absent are appended to `missing_keys`.

        Returns:
            The deserialized `WeightQBytesTensor`, or `None` if at least one entry is missing.
        """
        inner_tensors_dict = {}
        missing = False
        for name in ["_data", "_scale"]:
            if prefix + name not in state_dict:
                missing_keys.append(prefix + name)
                missing = True
            else:
                inner_tensors_dict[name] = state_dict.pop(prefix + name)

        if missing:  # could not deserialize because of missing keys
            return None

        meta = {
            "qtype": qtype.name,
            "axis": str(axis),
            "size": str(list(size)),
            "stride": str(list(stride)),
            "activation_qtype": "none" if activation_qtype is None else activation_qtype.name,
        }
        return WeightQBytesTensor.__tensor_unflatten__(inner_tensors_dict, meta, None, None)

    def optimize(self):
        """Allows to convert an existing WeightQBytesTensor to an optimized subclass

        This is used in particular after reloading a serialized WeightQBytesTensor (which is
        always saved using the kernel-agnostic packing).
        """
        if type(self) is not WeightQBytesTensor:
            # Already a subclass: nothing to do
            return self
        # Call dedicated helper to select the best subclass for this device
        return WeightQBytesTensor.create(
            self.qtype,
            self.axis,
            self.size(),
            self.stride(),
            self._data,
            self._scale,
            self.activation_qtype,
            self.requires_grad,
        )

    def save_to_state_dict(self, destination, prefix, keep_vars):
        """Serialize this tensor into `destination`, always as a plain `WeightQBytesTensor`."""
        if type(self) is WeightQBytesTensor:
            super().save_to_state_dict(destination, prefix, keep_vars)
        else:
            # Convert back subclass before serializing
            self.weight_qbytes_tensor().save_to_state_dict(destination, prefix, keep_vars)

    def weight_qbytes_tensor(self):
        """Convert back a subclass to a WeightQBytesTensor

        This is required to make sure only standard packing is used when serializing.
        """
        raise NotImplementedError

    def __tensor_flatten__(self):
        """Return the inner tensor attribute names and a string-only metadata dict."""
        inner_tensors = ["_data", "_scale"]
        meta = {
            "qtype": self._qtype.name,
            "axis": str(self._axis),
            "size": str(list(self.size())),
            "stride": str(list(self.stride())),
            "activation_qtype": "none" if self.activation_qtype is None else self.activation_qtype.name,
        }
        return inner_tensors, meta

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        """Rebuild a `WeightQBytesTensor` from the output of `__tensor_flatten__`.

        `outer_size` and `outer_stride` are ignored: the size and stride stored in
        `meta` are used instead.
        """
        assert len(inner_tensors) == 2
        assert len(meta) == 5
        data, scale = inner_tensors["_data"], inner_tensors["_scale"]
        # Meta should only contain strings, AST compatible except qtype
        qtype = qtypes[meta["qtype"]]
        axis = ast.literal_eval(meta["axis"])
        size = ast.literal_eval(meta["size"])
        stride = ast.literal_eval(meta["stride"])
        activation_qtype = None if meta["activation_qtype"] == "none" else qtypes[meta["activation_qtype"]]
        return WeightQBytesTensor(qtype, axis, size, stride, data, scale, activation_qtype)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch torch functions applied on this subtensor

        This method is called whenever a torch function (such as `torch.nn.functional.linear`)
        is called with at least one parameter corresponding to this subtensor:

        - if a quantized implementation exists for the selected function, it is called,
        - otherwise, the original implementation is called, deactivating further functional dispatch.

        During the execution of the standard torch function, a second-level of dispatch will
        happen, but this time directly on individual torch Tensor operations (mainly ATEN).
        """
        kwargs = kwargs or {}
        if func is torch.nn.functional.linear:

            def qlinear(input, other, bias=None):
                return WeightQBytesLinearFunction.apply(input, other, bias)

            return qlinear(*args, **kwargs)
        elif func is torch.equal:
            input, other = args
            return input.equal(other)
        # Defer to operations dispatcher
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(cls, op, types, args, kwargs=None):
        """Dispatch low-level tensor operations applied on this subtensor

        `detach`, `_to_copy`/`to` and (for the base class only) `t` are handled here
        on the inner data and scale tensors; every other operation goes to `qfallback`.

        Raises:
            ValueError: If a copy requests a dtype different from the tensor dtype.
        """
        # Normalize kwargs up-front: the copy branch below reads from it
        kwargs = kwargs or {}
        # Do not use directly op, but rather its overload
        op = op.overloadpacket
        if op is torch.ops.aten.detach:
            t = args[0]
            # Detach is required when copying and deserializing
            inner_tensor_names, meta = t.__tensor_flatten__()
            # Detach inner tensors
            detached_tensors = {}
            for inner_name in inner_tensor_names:
                detached_tensors[inner_name] = op(getattr(t, inner_name))
            return cls.__tensor_unflatten__(detached_tensors, meta, t.size(), t.stride())
        elif op in [torch.ops.aten._to_copy, torch.ops.aten.to]:
            t = args[0]
            dtype = kwargs.pop("dtype", t.dtype)
            device = kwargs.pop("device", t.device)
            if dtype != t.dtype:
                raise ValueError("The dtype of a weights Tensor cannot be changed")
            if type(t) is not WeightQBytesTensor and t.device.type != device.type:
                # Before moving to another device type, convert back to a WeightQBytesTensor
                t = t.weight_qbytes_tensor()
            out_data = op(t._data, device=device, **kwargs)
            out_scale = op(t._scale, device=device, **kwargs)
            # Go through the factory so that the best subclass is selected for the new device
            return WeightQBytesTensor.create(
                t.qtype,
                t.axis,
                t.size(),
                t.stride(),
                out_data,
                out_scale,
                activation_qtype=t.activation_qtype,
                requires_grad=t.requires_grad,
            )
        elif op is torch.ops.aten.t and cls is WeightQBytesTensor:
            t = args[0]
            out_data = op(t._data)
            out_scale = t._scale
            out_axis = t.axis
            # Manually reverse size and stride because we cannot trust the out_data shape
            dim0, dim1 = t.size()
            out_size = torch.Size([dim1, dim0])
            out_stride = t.stride()[::-1]
            if t.axis is not None:
                # We need to transpose also the scale
                out_scale = op(out_scale)
                out_axis = 0 if out_axis == -1 else -1
            return WeightQBytesTensor(t.qtype, out_axis, out_size, out_stride, out_data, out_scale, t.activation_qtype)
        # No dispatch available: qfallback
        return qfallback(op, *args, **kwargs)