optimum/quanto/tensor/weights/qbits.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
from typing import Optional
import torch
from packaging import version
from torch.autograd import Function
from ...library import is_extension_available
from ..function import QuantizedLinearFunction
from ..grouped import grouped_shape
from ..packed import PackedTensor
from ..qbits import QBitsTensor
from ..qtensor import qfallback
from ..qtype import qint2, qint4, qtype, qtypes
__all__ = ["WeightQBitsTensor"]
class WeightsQBitsQuantizer(Function):
@staticmethod
def forward(
ctx,
base: torch.Tensor,
qtype: qtype,
axis: int,
group_size: int,
scale: torch.Tensor,
shift: torch.Tensor,
optimized: bool,
):
if qtype not in (qint2, qint4):
raise ValueError("WeightQBitsTensor can only be of qint2 or qint4 qtype")
if axis not in (0, -1):
raise ValueError("WeightQBitsTensor axis parameter must be 0 (first axis) or -1 (last axis)")
size = base.size()
stride = base.stride()
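        # Project the float weights onto the integer grid defined by the affine scale and shift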
data = torch.ops.quanto.quantize_affine(
base, bits=qtype.bits, axis=axis, group_size=group_size, scale=scale, shift=shift
)
if optimized:
return WeightQBitsTensor.create(qtype, axis, group_size, size, stride, data, scale, shift)
return WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift)
@staticmethod
def backward(ctx, gO):
# For autograd, quantization is a no-op
return gO, None, None, None, None, None, None
class WeightQBitsTensor(QBitsTensor):
@staticmethod
def create(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
"""Factory method to create a WeightQBitsTensor
This selects the most appropriate WeightQBitsTensor based on the configuration.
        Args:
            qtype (`qtype`):
                The target quantization type (one of `qint2` or `qint4`).
            axis (`int`):
                The axis that is preserved by quantization (usually zero for linear weights).
group_size (`int`):
The group size that further splits the data elements for each index along the quantization axis.
            size (`torch.Size`):
                The Tensor size.
            stride (`tuple`):
                The Tensor stride.
data (`torch.Tensor`):
The tensor data, either as a raw uint8 torch.Tensor or as a PackedTensor.
scale (`torch.Tensor`):
The floating point scale expressed as a torch.Tensor.
            shift (`torch.Tensor`):
                The shift expressed as a torch.Tensor. It can be either an integer representing the
                zero-point or a float value.
            requires_grad (`bool`):
                Whether the Tensor must receive a gradient or not.
Returns:
a `WeightQBitsTensor` (can be a subclass).
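
        Example (illustrative sketch, not a definitive recipe: the packed `data`, `scale` and
        `shift` below are assumed to have been produced by `WeightQBitsTensor.quantize` or loaded
        from a state dict, and `w` is the original float weight):

            >>> qw = WeightQBitsTensor.create(
            ...     qint4, axis=0, group_size=128, size=w.size(), stride=w.stride(),
            ...     data=data, scale=scale, shift=shift,
            ... )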
"""
from .awq import AWQWeightQBitsTensor
from .tinygemm import TinyGemmWeightQBitsTensor
if (
qtype == qint4
and size[0] >= 128 # FIXME Workaround AWQ GEMM crash (GEMV might work for short inputs)
and scale.dtype == torch.float16
and axis == 0
and group_size == 128
and len(size) == 2
and (data.device.type == "cuda" and torch.version.cuda)
and torch.cuda.get_device_capability(data.device)[0] >= 8
and is_extension_available("quanto_cuda")
):
if type(data) is PackedTensor:
data = data.unpack()
return AWQWeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad)
if qtype == qint4 and scale.dtype == torch.bfloat16 and axis == 0 and group_size == 128 and len(size) == 2:
if data.device.type == "cpu" or (
(data.device.type == "cuda" and torch.version.cuda)
and version.parse(torch.version.cuda).release >= (12, 1)
and torch.cuda.get_device_capability(data.device)[0] >= 8
):
if type(data) is PackedTensor:
data = data.unpack()
return TinyGemmWeightQBitsTensor(
qtype, axis, group_size, size, stride, data, (scale, shift), requires_grad
)
return WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad)
@staticmethod
def __new__(cls, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
assert data.device == scale.device
assert data.device == shift.device
return torch.Tensor._make_wrapper_subclass(
cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad
)
def __init__(self, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
if type(data) is torch.Tensor:
data = PackedTensor.pack(data, qtype.bits)
super().__init__(qtype, axis, group_size, size, stride, data, scale, shift)
@classmethod
def quantize(
cls,
base: torch.Tensor,
qtype: qtype,
axis: int,
group_size: int,
scale: torch.Tensor,
shift: torch.Tensor,
optimized: Optional[bool] = True,
):
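        """Quantize a float weight Tensor into a WeightQBitsTensor.

        Illustrative sketch (the weight shape and the origin of `scale` and `shift` are
        assumptions; any per-group affine optimizer matching `axis` and `group_size` can
        produce them):

            >>> w = torch.randn(256, 512, dtype=torch.float16)
            >>> qw = WeightQBitsTensor.quantize(w, qint4, axis=0, group_size=128, scale=scale, shift=shift)
        """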
return WeightsQBitsQuantizer.apply(base, qtype, axis, group_size, scale, shift, optimized)
@staticmethod
def load_from_state_dict(state_dict, prefix, qtype, axis, group_size, size, stride, missing_keys):
if group_size is None:
data_size = size
data_stride = stride
else:
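            # With group quantization, the underlying data is stored using a grouped 2D shape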
data_size = grouped_shape(size, axis, group_size)
assert len(data_size) == 2
# In row major, inner dimension (stride 1) is the last one
data_stride = (data_size[1], 1)
inner_tensors_dict = {
"_data": PackedTensor.load_from_state_dict(
state_dict, prefix + "_data.", qtype.bits, data_size, data_stride, missing_keys=missing_keys
)
}
missing = inner_tensors_dict["_data"] is None
for name in ["_scale", "_shift"]:
if prefix + name not in state_dict:
missing_keys.append(prefix + name)
missing = True
else:
inner_tensors_dict[name] = state_dict.pop(prefix + name)
if missing: # could not deserialize because of missing keys
return None
meta = {
"qtype": qtype.name,
"axis": str(axis),
"group_size": str(group_size),
"size": str(list(size)),
"stride": str(list(stride)),
}
return WeightQBitsTensor.__tensor_unflatten__(inner_tensors_dict, meta, None, None)
def optimize(self):
"""Allows to convert an existing WeightQBitsTensor to an optimized subclass
This is used in particular after reloading a serialized WeightQBitsTensor (which is
always saved using the kernel-agnostic packing).
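
        Example (illustrative; `reloaded` is assumed to be a WeightQBitsTensor deserialized from
        a state dict):

            >>> qw = reloaded.optimize()  # may return an AWQ or TinyGemm subclass on capable devices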
"""
if type(self) is not WeightQBitsTensor:
return self
data = self._data.unpack()
# Call dedicated helper to select the best subclass for this device
return WeightQBitsTensor.create(
self.qtype,
self.axis,
self._group_size,
self.size(),
self.stride(),
data,
self._scale,
self._shift,
self.requires_grad,
)
def save_to_state_dict(self, destination, prefix, keep_vars):
if type(self) is WeightQBitsTensor:
super().save_to_state_dict(destination, prefix, keep_vars)
else:
            # Convert the subclass back to a WeightQBitsTensor before serializing
self.weight_qbits_tensor().save_to_state_dict(destination, prefix, keep_vars)
def weight_qbits_tensor(self):
"""Convert back a subclass to a WeightQBitsTensor
This is required to make sure only standard packing is used when serializing.
"""
raise NotImplementedError
def __tensor_flatten__(self):
inner_tensors = ["_data", "_scale", "_shift"]
# Since meta can be used for serialization, use only strings
meta = {
"qtype": self._qtype.name,
"axis": str(self._axis),
"group_size": str(self._group_size),
"size": str(list(self.size())),
"stride": str(list(self.stride())),
}
return inner_tensors, meta
@staticmethod
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
assert len(inner_tensors) == 3
assert len(meta) == 5
data, scale, shift = inner_tensors["_data"], inner_tensors["_scale"], inner_tensors["_shift"]
        # Meta should contain only strings: all values are AST-parsable, except qtype which is looked up by name
qtype = qtypes[meta["qtype"]]
axis = ast.literal_eval(meta["axis"])
group_size = ast.literal_eval(meta["group_size"])
size = ast.literal_eval(meta["size"])
stride = ast.literal_eval(meta["stride"])
return WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
"""Dispatch torch functions applied on this subtensor
This method is called whenever a torch function (such as `torch.nn.functional.linear`)
        is called with at least one parameter corresponding to this subtensor:
- if a quantized implementation exists for the selected function, it is called,
- otherwise, the original implementation is called, deactivating further functional dispatch.
        During the execution of the standard torch function, a second level of dispatch will
        happen, but this time directly on individual torch Tensor operations (mainly ATen).
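
        Example (illustrative sketch; `qweight` is assumed to be a WeightQBitsTensor and the
        activation shape is arbitrary):

            >>> x = torch.randn(8, qweight.size()[1], dtype=qweight.dtype, device=qweight.device)
            >>> y = torch.nn.functional.linear(x, qweight)  # routed to QuantizedLinearFunction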
"""
kwargs = kwargs or {}
if func is torch.nn.functional.linear:
def qlinear(input, other, bias=None):
return QuantizedLinearFunction.apply(input, other, bias)
return qlinear(*args, **kwargs)
elif func is torch.equal:
input, other = args
return input.equal(other)
# Defer to operations dispatcher
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
@classmethod
def __torch_dispatch__(cls, op, types, args, kwargs=None):
        # Do not use op directly, but rather its overload packet
op = op.overloadpacket
if op is torch.ops.aten.detach:
t = args[0]
# Detach is required when copying and deserializing
inner_tensor_names, meta = t.__tensor_flatten__()
# Detach inner tensors
detached_tensors = {}
for inner_name in inner_tensor_names:
detached_tensors[inner_name] = op(getattr(t, inner_name))
return cls.__tensor_unflatten__(detached_tensors, meta, t.size(), t.stride())
elif op in [torch.ops.aten._to_copy, torch.ops.aten.to]:
t = args[0]
dtype = kwargs.pop("dtype", t.dtype)
device = kwargs.pop("device", t.device)
if dtype is not None and dtype != t.dtype:
raise ValueError("The dtype of a WeightQBitsTensor cannot be changed")
if type(t) is not WeightQBitsTensor and t.device.type != device.type:
# Before moving to another device type, convert back to a WeightQBitsTensor
t = t.weight_qbits_tensor()
scale = op(t._scale, dtype=dtype, device=device, **kwargs)
data = op(t._data, device=device, **kwargs)
shift = op(t._shift, device=device, **kwargs)
return WeightQBitsTensor.create(t._qtype, t._axis, t._group_size, t.size(), t.stride(), data, scale, shift)
# No dispatch available: qfallback
kwargs = kwargs or {}
return qfallback(op, *args, **kwargs)