optimum/quanto/tensor/weights/tinygemm/packed.py (83 lines of code) (raw):
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
from copy import copy
import torch
from torch.utils import _pytree as pytree
# Public API of this module: only the packed tensor subclass is exported.
__all__ = ["TinyGemmPackedTensor"]
class TinyGemmPackedTensor(torch.Tensor):
    """A `torch.uint8` wrapper tensor around weights packed for the torch tinygemm int4 kernels.

    The wrapper reports the logical (unpacked) size and stride with dtype `torch.uint8`,
    while the device-specific packed payload is stored in `self._data`.
    Only `detach` and `_to_copy`/`to` operate on the packed payload directly; every other
    aten operation is executed on the unpacked tensor (see `__torch_dispatch__`).
    """

    @staticmethod
    def __new__(cls, data, size, stride, requires_grad=False):
        # TinyGemmPackedTensor represents uint8 data and can therefore NEVER require gradient
        assert requires_grad is False
        return torch.Tensor._make_wrapper_subclass(
            cls, size, strides=stride, dtype=torch.uint8, device=data.device, requires_grad=requires_grad
        )

    def __init__(self, data, size, stride, requires_grad=False):
        # Only the packed payload is stored; size/stride are carried by the wrapper created in __new__.
        self._data = data

    def __repr__(self):
        return f"TinyGemmPackedTensor({self._data})"

    @classmethod
    def pack(cls, t):
        """Pack a torch.Tensor for tinygemm kernel

        This packs uint4 weights as expected by the torch tinygemm mixed mm kernel.
        The packing layout is device specific (CPU, XPU, or any other device).

        Args:
            t (`torch.Tensor`):
                The un-packed `torch.uint8` tensor. It is indexed as a 2D tensor and its values are
                presumably in [0, 15] (uint4) — TODO confirm at call sites.
        Returns:
            A `TinyGemmPackedTensor` reporting the size and stride of `t`.
        """
        inner_ktiles = 2
        t = t.to(torch.int32).contiguous()
        if t.device.type == "cpu":
            # The CPU packing kernel consumes the int32 tensor directly
            data = torch._convert_weight_to_int4pack_for_cpu(t, innerKTiles=inner_ktiles)
        elif t.device.type == "xpu":
            # XPU layout: odd columns in the high nibble, even columns in the low nibble
            t_uint8 = (t[:, 1::2] << 4 | t[:, ::2]).to(torch.uint8)
            data = torch._convert_weight_to_int4pack(t_uint8, innerKTiles=inner_ktiles)
        else:
            # Other devices: even columns in the high nibble, odd columns in the low nibble
            t_uint8 = (t[:, ::2] << 4 | t[:, 1::2]).to(torch.uint8)
            data = torch._convert_weight_to_int4pack(t_uint8, innerKTiles=inner_ktiles)
        # We need to store size and stride to make sure the unpacked data has the correct shape
        return TinyGemmPackedTensor(data, t.size(), t.stride())

    def unpack(self):
        """Unpack the packed tensor to a torch.Tensor

        Packing is device specific and implemented in undocumented dedicated kernels
        that are synchronized with the corresponding matrix multiplication operation.

        Instead of implementing a dedicated unpacking code, we pass an identity matrix
        to the mm operation with identity scale and shifts to produce the unpacked uint8 weights.

        Returns:
            An unpacked uint8 `torch.Tensor` of shape (out_features, in_features).
        """
        out_features, in_features = self.size()
        # We need to pass a group_size to the mm and format the scale and shift accordingly,
        # although it does not modify the calculation since we use identity scales and shifts.
        # We arbitrarily choose the smallest group_size to be sure it divides in_features
        group_size = 32
        scale_and_shift_shape = (in_features // group_size, out_features, 2)
        # Initialize identity scale
        id_scale_and_shift = torch.ones(scale_and_shift_shape, dtype=torch.bfloat16, device=self.device)
        # Set shift to mid-point, i.e. 2 **(bits - 1)
        id_scale_and_shift[:, :, 1] = 8

        identity = torch.eye(in_features, dtype=torch.bfloat16, device=self.device)
        if self._data.device.type == "cpu":
            unpacked_data = torch._weight_int4pack_mm_for_cpu(identity, self._data, group_size, id_scale_and_shift)
        else:
            unpacked_data = torch._weight_int4pack_mm(identity, self._data, group_size, id_scale_and_shift)

        # Transpose the mm output back to the (out_features, in_features) layout of this tensor
        return unpacked_data.t().to(torch.uint8)

    @property
    def dtype(self):
        return torch.uint8

    def __tensor_flatten__(self):
        """Return the inner tensor attribute names and AST-compatible metadata for serialization."""
        inner_tensors = ["_data"]
        # Since meta can be used for serialization, use only AST compatible strings
        meta = {
            "size": str(list(self.size())),
            "stride": str(self.stride()),
        }
        return inner_tensors, meta

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        """Rebuild a `TinyGemmPackedTensor` from the output of `__tensor_flatten__`."""
        assert len(inner_tensors) == 1
        assert len(meta) == 2
        data = inner_tensors["_data"]
        # Meta should contain only AST compatible strings
        size = ast.literal_eval(meta["size"])
        stride = ast.literal_eval(meta["stride"])
        return TinyGemmPackedTensor(data, size, stride)

    # Disable __torch_function__ so that operations are routed to __torch_dispatch__
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, op, types, args, kwargs=None):
        """Handle aten operations on packed tensors.

        `detach` and `_to_copy`/`to` keep the data packed (re-packing if the device type changes).
        Any other operation receives unpacked tensors.

        Raises:
            ValueError: if a `_to_copy`/`to` requests a dtype other than `torch.uint8`.
        """
        # kwargs may be None: normalize once so every branch can use it safely
        kwargs = kwargs or {}
        # Convert back to tensor before calling any operation except detach and move
        if op.overloadpacket is torch.ops.aten.detach:
            t = args[0]
            data = op(t._data)
            return TinyGemmPackedTensor(data, t.size(), t.stride())
        elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to):
            t = args[0]
            # An absent or explicit None dtype means "keep the current dtype" (uint8)
            dtype = kwargs.get("dtype")
            if dtype is None:
                dtype = torch.uint8
            if dtype != torch.uint8:
                raise ValueError(f"TinyGemmPackedTensor are torch.uint8 only and cannot be moved to {dtype}.")
            # An absent or explicit None device means "stay on the current device"
            device = kwargs.get("device")
            device = t.device if device is None else torch.device(device)
            # The packed payload keeps its own dtype, which differs from the advertised uint8
            data_kwargs = copy(kwargs)
            data_kwargs["dtype"] = t._data.dtype
            if device.type != t.device.type:
                # Packing is device specific, so we need to unpack before moving
                unpacked = t.unpack()
                unpacked = op(unpacked, **data_kwargs)
                return TinyGemmPackedTensor.pack(unpacked)
            # If we stay on the same device type, just copy/move packed data
            data = op(t._data, **data_kwargs)
            return TinyGemmPackedTensor(data, t.size(), t.stride())
        args, kwargs = pytree.tree_map_only(TinyGemmPackedTensor, lambda x: x.unpack(), (args, kwargs))
        return op(*args, **kwargs)

    def numpy(self):
        """Return the unpacked weights as a numpy array (moved to CPU first)."""
        return self.unpack().cpu().numpy()