in optimum/quanto/tensor/weights/qbits.py [0:0]
def create(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
    """Factory method to create a WeightQBitsTensor.

    This selects the most appropriate WeightQBitsTensor subclass based on the
    configuration: an AWQ kernel tensor for float16 scales on capable CUDA
    devices, a TinyGemm tensor for bfloat16 scales on CPU or recent CUDA, and
    the generic WeightQBitsTensor otherwise.

    Args:
        qtype (`quanto.qtype`):
            The target quantized data type (only `qint4` can select an optimized subclass).
        axis (`int`):
            The axis that is preserved by quantization (usually zero for linear weights).
        group_size (`int`):
            The group size that further splits the data elements for each index along the quantization axis.
        size (sequence of `int`):
            The Tensor size (indexed and measured with `len` below, so any sized sequence works).
        stride (sequence of `int`):
            The Tensor stride.
        data (`torch.Tensor`):
            The tensor data, either as a raw uint8 torch.Tensor or as a PackedTensor.
        scale (`torch.Tensor`):
            The floating point scale expressed as a torch.Tensor.
        shift (`torch.Tensor`):
            The shift expressed as a torch.Tensor. It can be either an integer representing zero
            (i.e. zero-point) or a float value.
        requires_grad (`bool`):
            If the Tensor must receive a gradient or not.

    Returns:
        a `WeightQBitsTensor` (can be a subclass).
    """
    # Imported lazily here, presumably to avoid a circular import with the
    # subclass modules — confirm before hoisting to module level.
    from .awq import AWQWeightQBitsTensor
    from .tinygemm import TinyGemmWeightQBitsTensor

    # AWQ path: int4, fp16 scales, per-output-channel (axis 0) quantization with
    # 128-wide groups on a 2D weight, on a CUDA (not ROCm: torch.version.cuda is
    # checked) device of compute capability >= 8, with the quanto CUDA extension
    # built. Note the device check short-circuits before get_device_capability,
    # which would raise on a non-CUDA device.
    if (
        qtype == qint4
        and size[0] >= 128  # FIXME Workaround AWQ GEMM crash (GEMV might work for short inputs)
        and scale.dtype == torch.float16
        and axis == 0
        and group_size == 128
        and len(size) == 2
        and (data.device.type == "cuda" and torch.version.cuda)
        and torch.cuda.get_device_capability(data.device)[0] >= 8
        and is_extension_available("quanto_cuda")
    ):
        # Exact type check (not isinstance) — presumably deliberate, so that
        # specialized packed subclasses are not unpacked here; confirm.
        if type(data) is PackedTensor:
            data = data.unpack()
        return AWQWeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad)
    # TinyGemm path: same int4 / axis-0 / group-128 / 2D layout but bf16 scales,
    # available on CPU, or on CUDA >= 12.1 with compute capability >= 8.
    if qtype == qint4 and scale.dtype == torch.bfloat16 and axis == 0 and group_size == 128 and len(size) == 2:
        if data.device.type == "cpu" or (
            (data.device.type == "cuda" and torch.version.cuda)
            and version.parse(torch.version.cuda).release >= (12, 1)
            and torch.cuda.get_device_capability(data.device)[0] >= 8
        ):
            if type(data) is PackedTensor:
                data = data.unpack()
            # TinyGemm packs scale and shift together as a single tuple argument.
            return TinyGemmWeightQBitsTensor(
                qtype, axis, group_size, size, stride, data, (scale, shift), requires_grad
            )
    # Fallback: generic implementation; data is left packed if it arrived packed.
    return WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad)