optimum/quanto/tensor/activations/qbytes_ops.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
from functools import partial
from typing import Callable, List
import torch
from ..core import dtype_info
from ..qtensor import QTensor, qfallback
from ..qtype import qint8
from .qbytes import ActivationQBytesTensor
from .quantization import quantize_activation
__all__ = ["get_qbytestensor_op_dispatch", "register_qbytestensor_op"]
_QBYTESTENSOR_OP_TABLE = {}
def register_qbytestensor_op(aten_ops: List[Callable]):
"""
    Used to register a new __torch_dispatch__ aten operation for QBytesTensor.
The code to register a new operation looks like:
@register_qbytestensor_op(list_of_ops)
def foo(op, *args, **kwargs):
<implementation>
"""
def wrapper(op):
for aten_op in aten_ops:
_QBYTESTENSOR_OP_TABLE[aten_op] = partial(op, aten_op)
return wrapper
def get_qbytestensor_op_dispatch(aten_op):
return _QBYTESTENSOR_OP_TABLE.get(aten_op, None)
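# Illustrative sketch only (assumed outline, not this module's code): the quantized
# tensor's __torch_dispatch__ would typically consult this table by overload packet
# and fall back to dequantization when no handler is registered, e.g.:
#
#   @classmethod
#   def __torch_dispatch__(cls, op, types, args=(), kwargs=None):
#       kwargs = kwargs or {}
#       qdispatch = get_qbytestensor_op_dispatch(op.overloadpacket)
#       if qdispatch is not None:
#           return qdispatch(*args, **kwargs)
#       # No dedicated handler: dequantize, apply the op, return a plain tensor
#       return qfallback(op, *args, **kwargs)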
def is_scalar(t):
    return isinstance(t, numbers.Number) or (type(t) is torch.Tensor and len(t.shape) == 0)
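# For example, is_scalar(2.0) and is_scalar(torch.tensor(2.0)) are both True, while
# is_scalar(torch.tensor([2.0])) is False because it has one dimension.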
@register_qbytestensor_op([torch.ops.aten._to_copy, torch.ops.aten.to])
def _to_copy(op, t, dtype=None, **kwargs):
# For data, ignore dtype and use the inner type instead
out_data = op(t._data, dtype=t._data.dtype, **kwargs)
# Apply the new dtype on the scale only
out_scale = op(t._scale, dtype=dtype, **kwargs)
return ActivationQBytesTensor(t.qtype, t.size(), t.stride(), out_data, out_scale)
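# Rationale for _to_copy above: the quantized payload keeps its storage dtype (e.g. int8),
# so a dtype change only affects the scale, and therefore the dtype that dequantize()
# will produce.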
@register_qbytestensor_op([torch.ops.aten.detach])
def detach(op, t):
# Detach both data and scale
out_data = op(t._data)
out_scale = op(t._scale)
return ActivationQBytesTensor(t.qtype, t.size(), t.stride(), out_data, out_scale)
@register_qbytestensor_op([torch.ops.aten.cat])
def cat(op, inputs, dim=0):
if len(inputs) == 2:
t1, t2 = inputs
# Only quantized tensors with identical scalar scales can be concatenated
if (
isinstance(t1, ActivationQBytesTensor)
and isinstance(t2, ActivationQBytesTensor)
and torch.equal(t1._scale, t2._scale)
and t1.qtype == t2.qtype
):
if t1.qtype.is_floating_point or t2.qtype.is_floating_point:
# Cat is not supported for float8
return qfallback(op, inputs, dim)
out_data = op([t1._data, t2._data], dim)
return ActivationQBytesTensor(t1.qtype, out_data.size(), out_data.stride(), out_data, t1._scale)
return qfallback(op, inputs, dim)
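# Rationale for cat above: with a shared scale s, cat([a, b]) * s == cat([a * s, b * s]),
# so concatenating the integer payloads is exact; tensors with different scales fall
# back to dequantization.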
@register_qbytestensor_op([torch.ops.aten.lt])
def lt(op, input, other):
# Only quantized tensors with identical scales can be compared
if (
isinstance(input, ActivationQBytesTensor)
and isinstance(other, ActivationQBytesTensor)
and torch.equal(input._scale, other._scale)
):
return op(input._data, other._data)
return qfallback(op, input, other)
@register_qbytestensor_op([torch.ops.aten.clone])
def clone(op, t, memory_format=torch.preserve_format):
    # Temporarily reshape the data to the tensor shape so that the clone gets the
    # correct strides, then restore the original data shape afterwards
    data_shape = t._data.shape
    out_data = t._data.reshape(t.shape)
    out_data = op(out_data, memory_format=memory_format)
out_stride = out_data.stride()
out_data = out_data.reshape(data_shape)
out_scale = op(t._scale, memory_format=memory_format)
return ActivationQBytesTensor(t.qtype, t.size(), out_stride, out_data, out_scale)
@register_qbytestensor_op([torch.ops.aten.copy_])
def copy_(op, dest, src):
assert dest.qtype == src.qtype
dest._data = op(dest._data, src._data)
dest._scale = op(dest._scale, src._scale)
return dest
@register_qbytestensor_op([torch.ops.aten.div])
def div(op, input, other):
if not is_scalar(other):
return op(input.dequantize(), other)
# We just divide the scale
return ActivationQBytesTensor(input.qtype, input.size(), input.stride(), input._data, op(input._scale, other))
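# Rationale for div above: the tensor dequantizes to _data * _scale, so dividing by a
# scalar k is exactly (_data * _scale) / k == _data * (_scale / k).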
@register_qbytestensor_op([torch.ops.aten.neg])
def neg(op, input, *args, **kwargs):
if input.qtype.is_floating_point:
# Neg is not supported for float8
return op(input.dequantize(), *args, **kwargs)
out_data = op(input._data, *args, **kwargs)
return ActivationQBytesTensor(input.qtype, input.size(), input.stride(), out_data, input._scale)
@register_qbytestensor_op(
[
torch.ops.aten.expand,
torch.ops.aten.permute,
torch.ops.aten.select,
torch.ops.aten.slice,
torch.ops.aten.unsqueeze,
]
)
def unary_type_agnostic_op(op, input, *args, **kwargs):
if input.axis is not None:
return op(input.dequantize(), *args, **kwargs)
# When quantization is per-tensor, these operations can be transparently applied
# without modifying the scale.
out_data = op(input._data, *args, **kwargs)
return ActivationQBytesTensor(input.qtype, out_data.size(), out_data.stride(), out_data, input._scale)
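# Rationale for the ops above: with a per-tensor (scalar) scale, dequantization is
# element-wise and shape-independent, so pure shape/indexing ops commute with it and
# can be applied directly to the payload while keeping the same scale.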
@register_qbytestensor_op([torch.ops.aten.is_same_size])
def is_same_size(op, input, other):
a = input._data if isinstance(input, ActivationQBytesTensor) else input
b = other._data if isinstance(other, ActivationQBytesTensor) else other
return op(a, b)
def cannot_mm(t: QTensor):
"""True if the QTensor data cannot be passed to an mm op"""
return t.axis is not None and t.size() != t._data.size()
@register_qbytestensor_op([torch.ops.aten.bmm])
def bmm(op, input, other):
if not isinstance(input, ActivationQBytesTensor):
return op(input, other.dequantize())
if not isinstance(other, QTensor) or input.axis is not None:
return op(input.dequantize(), other)
if input.qtype != qint8 or other.qtype != qint8 or cannot_mm(other):
return qfallback(op, input, other)
# Cast data to float32 and do the operation
out_data = op(input._data.to(torch.float32), other._data.to(torch.float32))
out_scale = (input._scale * other._scale).to(torch.float32)
return (out_data * out_scale).to(input._scale.dtype)
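# Rationale for bmm above: with per-tensor scales,
# (x_data * x_scale) @ (w_data * w_scale) == (x_data @ w_data) * (x_scale * w_scale),
# so the integer payloads are multiplied first (in float32 to avoid int8 overflow)
# and rescaled once at the end. The result is a plain, dequantized tensor.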
@register_qbytestensor_op([torch.ops.aten.mul])
def mul(op, input, other):
# If one of the multiplicands is a scalar, just multiply the scale
if is_scalar(input):
return ActivationQBytesTensor(other.qtype, other.size(), other.stride(), other._data, input * other._scale)
if is_scalar(other):
return ActivationQBytesTensor(input.qtype, input.size(), input.stride(), input._data, other * input._scale)
return qfallback(op, input, other)
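# Rationale for mul above: multiplying by a scalar k folds into the scale, since
# _data * (k * _scale) == k * (_data * _scale); products of two quantized tensors
# fall back to dequantization.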
@register_qbytestensor_op([torch.ops.aten.relu])
def relu(op, input):
if input.qtype.is_floating_point:
# Relu is not supported for float8 types
return qfallback(op, input)
out_data = op(input._data)
return ActivationQBytesTensor(input.qtype, input.size(), input.stride(), out_data, input._scale)
@register_qbytestensor_op([torch.ops.aten._softmax])
def _softmax(op, input, dim, half_to_float):
# Softmax must be performed in float
float_data = op(input.dequantize(), dim, half_to_float)
# Since softmax is normalized, we know the optimal scale
out_scale = torch.tensor(1 / dtype_info(input.qtype.dtype).max, dtype=input._scale.dtype).to(input.device)
return quantize_activation(float_data, qtype=input.qtype, scale=out_scale)
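# Rationale for _softmax above: softmax outputs lie in [0, 1], so the smallest scale
# that still represents 1.0 without clipping is 1 / qmax (e.g. 1 / 127 for qint8),
# which maximizes the precision of the requantized result.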
@register_qbytestensor_op([torch.ops.aten.stack])
def stack(op, inputs, dim=0):
if len(inputs) == 2:
t1, t2 = inputs
        # Only per-tensor quantized tensors with identical scales can be stacked
if (
isinstance(t1, ActivationQBytesTensor)
and isinstance(t2, ActivationQBytesTensor)
and t1.axis is None
and t2.axis is None
and torch.equal(t1._scale, t2._scale)
and t1.qtype == t2.qtype
):
out_data = op([t1._data, t2._data], dim)
return ActivationQBytesTensor(t1.qtype, out_data.size(), out_data.stride(), out_data, t1._scale)
    return qfallback(op, inputs, dim)
@register_qbytestensor_op([torch.ops.aten.split])
def split(op, input, *args, **kwargs):
if input.axis is not None:
return qfallback(op, input, *args, **kwargs)
out_datas = op(input._data, *args, **kwargs)
return [
ActivationQBytesTensor(input.qtype, input.size(), input.stride(), out_data, input._scale)
for out_data in out_datas
]
@register_qbytestensor_op([torch.ops.aten.transpose])
def transpose(op, input, *args):
out_data = op(input._data, *args)
out_size = out_data.size()
out_stride = out_data.stride()
out_scale = input._scale
return ActivationQBytesTensor(input.qtype, out_size, out_stride, out_data, out_scale)
@register_qbytestensor_op([torch.ops.aten.t])
def transpose2d(op, input):
out_data = op(input._data)
out_scale = input._scale
# Manually reverse size and stride because we cannot trust the out_data shape
dim0, dim1 = input.size()
out_size = torch.Size([dim1, dim0])
out_stride = input.stride()[::-1]
return ActivationQBytesTensor(input.qtype, out_size, out_stride, out_data, out_scale)
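# Example for transpose2d above: a contiguous (m, n) tensor has strides (n, 1); its
# transpose has size (n, m) and strides (1, n), i.e. the reversed size and strides
# used above.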
@register_qbytestensor_op([torch.ops.aten.view, torch.ops.aten._unsafe_view])
def view(op, input, *shape):
if input.axis is None:
# The view is transparent for QTensor with scalar scales
out_data = op(input._data, *shape)
return ActivationQBytesTensor(input.qtype, out_data.size(), out_data.stride(), out_data, input._scale)
return qfallback(op, input, *shape)
@register_qbytestensor_op([torch.ops.aten.where])
def where(op, condition, input, other):
if isinstance(condition, QTensor) or isinstance(other, QTensor):
raise NotImplementedError
float_data = op(condition, input.dequantize(), other)
if input.axis is None:
# We requantize with the input scale
return quantize_activation(float_data, qtype=input.qtype, scale=input._scale)
return float_data
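# Minimal usage sketch (assumed entry points, for illustration only): once an
# activation has been quantized, regular torch operations on it are routed through
# the handlers registered above, e.g.:
#
#   x = torch.randn(2, 16)
#   scale = x.abs().max() / dtype_info(torch.int8).max
#   qx = quantize_activation(x, qtype=qint8, scale=scale)
#   qy = qx / 2.0          # handled by `div`: only the scale changes
#   qz = torch.relu(qx)    # handled by `relu`: applied directly to the int8 payload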