optimum/quanto/tensor/qtensor.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.utils import _pytree as pytree

__all__ = ["QTensor", "qfallback"]


def qfallback(callable, *args, **kwargs):
"""Fallback method for QTensor inputs.
When a torch function or an aten operation is not supported for the specified
QTensor arguments, each QTensor arg or kwarg is dequantized to a torch.Tensor
before calling the target function or op.
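
    Example (illustrative; assumes a concrete QTensor instance ``qt``):

        out = qfallback(torch.nn.functional.softmax, qt, dim=-1)
        # equivalent to: torch.nn.functional.softmax(qt.dequantize(), dim=-1)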
"""
args, kwargs = pytree.tree_map_only(QTensor, lambda x: x.dequantize(), (args, kwargs or {}))
return callable(*args, **kwargs)


class QTensor(torch.Tensor):
    def __init__(self, qtype, axis):
        # This base constructor only stores metadata: qtype (the quantized data
        # type) and axis (the quantization axis, or None for per-tensor
        # quantization). Subclasses are responsible for allocating the wrapper
        # Tensor itself, typically through torch.Tensor._make_wrapper_subclass.
        self._qtype = qtype
        self._axis = axis

    def dequantize(self):
        """Dequantize this QTensor back to a plain torch.Tensor.

        Must be implemented by concrete subclasses.
        """
        raise NotImplementedError

    def save_to_state_dict(self, destination, prefix, keep_vars):
        """Serialize this QTensor into a state_dict as plain torch.Tensor items.

        The inner tensors returned by __tensor_flatten__ are stored under
        `prefix + name`, recursing into inner tensors that are themselves
        tensor subclasses.
        """
def serialize_tensor_subclass(t, destination, prefix, keep_vars):
inner_tensors, meta = t.__tensor_flatten__()
for name in inner_tensors:
inner_tensor = getattr(t, name)
if type(inner_tensor) is torch.Tensor:
# Leaf Tensor, we can serialize it
destination[prefix + name] = inner_tensor if keep_vars else inner_tensor.detach()
else:
# Flatten also this inner Tensor
serialize_tensor_subclass(inner_tensor, destination, prefix + name + ".", keep_vars)
# Recursively flatten QTensor into individual tensors
serialize_tensor_subclass(self, destination, prefix, keep_vars)
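
    # For example (illustrative; assuming a subclass whose __tensor_flatten__
    # returns inner tensors named "_data" and "_scale"), calling this method
    # with prefix="weight." would store:
    #
    #   destination["weight._data"]  -> the quantized data tensor
    #   destination["weight._scale"] -> the scales tensor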

    @property
def axis(self):
return self._axis

    @property
def qtype(self):
return self._qtype

    def numpy(self):
return self.dequantize().cpu().numpy()

    def equal(self, other):
        """Compare type, metadata and inner tensors for strict equality."""
if type(self) is not type(other):
return False
self_tensors, self_meta = self.__tensor_flatten__()
_, other_meta = other.__tensor_flatten__()
for name, value in self_meta.items():
if other_meta[name] != value:
return False
for name in self_tensors:
self_t = getattr(self, name)
other_t = getattr(other, name)
if self_t.device.type == "cpu" and self_t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
# torch.equal is not implemented on CPU for float8 types
if self_t.dtype != other_t.dtype:
return False
if not torch.equal(self_t.to(torch.float32), other_t.to(torch.float32)):
return False
elif not torch.equal(self_t, other_t):
return False
return True
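

# A minimal concrete subclass sketch (illustrative only, not part of the
# library), showing the contract this base class relies on: subclasses
# allocate the wrapper through torch.Tensor._make_wrapper_subclass, expose
# their inner tensors through __tensor_flatten__, and implement dequantize().
# The name `ExampleQTensor` and the scale-only quantization scheme are
# hypothetical; real subclasses also implement __tensor_unflatten__ and
# __torch_dispatch__.
#
#     class ExampleQTensor(QTensor):
#         @staticmethod
#         def __new__(cls, qtype, axis, data, scale, requires_grad=False):
#             # The wrapper advertises the dequantized dtype and shape; the
#             # actual storage lives in the inner `_data` and `_scale` tensors.
#             return torch.Tensor._make_wrapper_subclass(
#                 cls, data.size(), dtype=scale.dtype, requires_grad=requires_grad
#             )
#
#         def __init__(self, qtype, axis, data, scale, requires_grad=False):
#             super().__init__(qtype, axis)
#             self._data = data
#             self._scale = scale
#
#         def __tensor_flatten__(self):
#             # Inner tensor names plus string metadata, as consumed by
#             # save_to_state_dict() and equal() above.
#             return ["_data", "_scale"], {"qtype": str(self._qtype), "axis": str(self._axis)}
#
#         def dequantize(self):
#             # Rescale the integer data back to the scale dtype.
#             return self._data.to(self._scale.dtype) * self._scale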