optimum/quanto/tensor/activations/qbytes.py (57 lines of code) (raw):

# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import ast import torch from torch.autograd import Function from ..qbytes import QBytesTensor from ..qtensor import qfallback from ..qtype import qtype, qtypes __all__ = ["ActivationQBytesTensor"] class ActivationQBytesQuantizer(Function): @staticmethod def forward(ctx, base: torch.Tensor, qtype: qtype, scale: torch.Tensor) -> torch.Tensor: if qtype.bits != 8: raise ValueError("QBytesTensor can only be of 8-bit qtype") size = base.size() stride = base.stride() data = torch.ops.quanto.quantize_symmetric(base, dtype=qtype.dtype, axis=None, scale=scale) # The instantiation of the quantized tensor must happen within the context of the Function # for the autograd magic to work. return ActivationQBytesTensor(qtype, size, stride, data, scale) @staticmethod def backward(ctx, gO): # For autograd, quantization is a no-op return gO, None, None, None, None, None class ActivationQBytesTensor(QBytesTensor): @staticmethod def __new__(cls, qtype, size, stride, data, scale, requires_grad=False): assert data.device == scale.device return torch.Tensor._make_wrapper_subclass( cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad ) def __init__(self, qtype, size, stride, data, scale, requires_grad=False): super().__init__(qtype, None, size, stride, data, scale, requires_grad) @classmethod def quantize(cls, base: torch.Tensor, qtype: qtype, scale: torch.Tensor) -> torch.Tensor: return ActivationQBytesQuantizer.apply(base, qtype, scale) def __tensor_flatten__(self): inner_tensors = ["_data", "_scale"] meta = { "qtype": self._qtype.name, "size": str(list(self.size())), "stride": str(list(self.stride())), } return inner_tensors, meta @staticmethod def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): assert len(inner_tensors) == 2 assert len(meta) == 3 data, scale = inner_tensors["_data"], inner_tensors["_scale"] # Meta should only contain strings, AST compatible except qtype qtype = qtypes[meta["qtype"]] size = ast.literal_eval(meta["size"]) stride = ast.literal_eval(meta["stride"]) return ActivationQBytesTensor(qtype, size, stride, data, scale) @classmethod def __torch_dispatch__(cls, op, types, args, kwargs=None): from .qbytes_ops import get_qbytestensor_op_dispatch kwargs = kwargs or {} # Do not use directly op, but rather its overload op = op.overloadpacket qdispatch = get_qbytestensor_op_dispatch(op) if qdispatch is not None: return qdispatch(*args, **kwargs) # No dispatch available: qfallback return qfallback(op, *args, **kwargs)