optimum/quanto/tensor/weights/marlin/fp8/packed.py (141 lines of code) (raw):
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
from copy import copy
import torch
from torch.utils import _pytree as pytree
def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements).
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn
if fp8_tensor.shape[0] % 4 != 0:
raise ValueError(f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}")
# Reshape to prepare for packing
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
# Convert fp8 to uint8 (byte) representation
byte_tensor = reshaped.view(torch.uint8)
# Pack 4 uint8 values into one int32
packed = torch.zeros(
fp8_tensor.shape[0] // 4,
fp8_tensor.shape[1],
dtype=torch.int32,
device=fp8_tensor.device,
)
for i in range(4):
packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)
return packed
def unpack_int32_to_fp8(int32_tensor: torch.Tensor) -> torch.Tensor:
"""
Reinterpret a tensor (a, b) of type int32 to a tensor (a * 4, b) of type float8_e4m3fn.
"""
bits = 8
unpacked = []
# Unpack each set of values independently
for i in range(4):
mask = 2 ** (bits * (i + 1)) - 1
tmp = (int32_tensor & mask) >> bits * i
tmp = tmp.to(torch.uint8)
unpacked.append(tmp)
# Return the concatenated unpacked tensors
unpacked = torch.cat(unpacked).view(torch.float8_e4m3fn)
return unpacked
def get_scale_perms() -> torch.Tensor:
    """
    Build a fixed 32-entry index tensor (dtype int64).

    The entries are `2 * i + j` for `i` in 0..3 (outer) and `j` in
    (0, 1, 8, 9, 16, 17, 24, 25) (inner), i.e. a permutation of 0..31.
    NOTE(review): presumably the scale ordering expected by the Marlin kernel; confirm against callers.
    """
    pair_offsets = (0, 1, 8, 9, 16, 17, 24, 25)
    permutation = [2 * group + offset for group in range(4) for offset in pair_offsets]
    return torch.tensor(permutation, dtype=torch.int64)
def get_row_permutation(n_rows: int) -> torch.Tensor:
    """
    Generates a tensor of shape (4 * n_rows,) giving the rows mapping to map from marlin-repacked weights to natural order.

    Example: if n_rows = 8, the row mapping from natural to marlin format is
    rows_idx = [0, 2, 4, 6,
                16, 18, 20, 22,
                8, 10, 12, 14,
                24, 26, 28, 30,
                1, 3, 5, 7,
                17, 19, 21, 23,
                9, 11, 13, 15,
                25, 27, 29, 31].

    Args:
        n_rows: a positive multiple of 4.

    Returns:
        An int64 tensor holding the inverse of the natural-to-marlin row mapping.

    Raises:
        ValueError: if `n_rows` is not a positive multiple of 4. For such values the mapping
            is not a permutation and cannot be inverted.
    """
    if n_rows < 4 or n_rows % 4 != 0:
        raise ValueError(f"n_rows must be a positive multiple of 4, got {n_rows}")

    total_rows = 4 * n_rows

    # Starting rows of the groups of 8: first every multiple of 16, then every 8 + 16 * k.
    group_starts = torch.cat((torch.arange(0, total_rows, 16), torch.arange(8, total_rows, 16)))

    # All even indexes, and then all odd indexes.
    group_starts = torch.cat((group_starts, group_starts + 1))

    # Indexes are grouped by four, each spaced by 2.
    rows_idx = (group_starts[:, None] + torch.tensor([[0, 2, 4, 6]])).reshape(-1)

    # `rows_idx` holds the mapping of natural rows to marlin rows, so inverse it.
    rows_idx_rev = torch.empty_like(rows_idx)
    rows_idx_rev[rows_idx] = torch.arange(len(rows_idx))

    return rows_idx_rev
def get_column_permutation(n_col: int) -> torch.Tensor:
    """
    Gets the column mapping to map from marlin-repacked weights to natural order.

    Columns are handled by tiles of 256 values (`n_col // 256` tiles). Within a tile,
    position `a` is assigned the index `32 * (a % 8) + a // 8`, offset by the start of
    its tile. The resulting indices are then regrouped by chunks of 64: chunk 0 of every
    tile comes first (tiles in order), then chunk 1 of every tile, and so on up to chunk 3.

    Returns:
        A 1-D int64 tensor with `256 * (n_col // 256)` entries.
    """
    tile_size = 256
    chunk_size = 64
    chunks_per_tile = tile_size // chunk_size
    n_tiles = n_col // tile_size

    # Index of each position inside a single tile.
    position = torch.arange(tile_size)
    in_tile_index = 32 * (position % 8) + position // 8

    # Replicate for every tile, shifting by the tile start: shape (n_tiles, tile_size).
    tile_starts = torch.arange(n_tiles)[:, None] * tile_size
    index = tile_starts + in_tile_index

    # View as (tile, chunk, 64) and swap the first two axes, so that flattening
    # enumerates the chunks as (chunk 0 of all tiles, chunk 1 of all tiles, ...).
    # E.g. with 3 tiles the 64-wide chunks are read in order [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11].
    regrouped = index.reshape(n_tiles, chunks_per_tile, chunk_size).transpose(0, 1)

    return regrouped.reshape(-1)
class MarlinF8PackedTensor(torch.Tensor):
    """
    Wrapper tensor subclass holding FP8 weights packed as int32 by `torch.ops.quanto.pack_fp8_marlin`.

    The wrapper reports the `size` and `stride` given at construction and a `torch.int32`
    dtype, while the packed payload is stored in the `_data` attribute.
    """

    def __new__(cls, data, size, stride, requires_grad=False):
        # The packed payload must be an int32 CUDA tensor, and gradients are not supported.
        assert data.device.type == "cuda"
        assert data.dtype == torch.int32
        assert requires_grad is False
        return torch.Tensor._make_wrapper_subclass(
            cls, size, strides=stride, dtype=torch.int32, device=data.device, requires_grad=requires_grad
        )

    def __init__(self, data, size, stride, requires_grad=False):
        # `size`, `stride` and `requires_grad` are consumed by `__new__`; only the payload is kept here.
        self._data = data

    def __repr__(self):
        return f"MarlinF8PackedTensor({self._data})"

    @classmethod
    def pack(cls, tensor: torch.Tensor):
        """
        Pack a 2-D (out_features, in_features) fp8 tensor into a `MarlinF8PackedTensor`.

        The transposed tensor is first packed to int32 with `pack_fp8_as_int32`, then
        repacked by `torch.ops.quanto.pack_fp8_marlin`. The returned wrapper reports the
        size and stride of the original `tensor`.
        """
        out_features, in_features = tensor.shape
        data_int32 = pack_fp8_as_int32(tensor.T)  # pack fp8 data to int32.
        # An empty `perm` tensor is passed to the packing op.
        perm = torch.empty(0, dtype=torch.int, device=tensor.device)

        data_int32 = torch.ops.quanto.pack_fp8_marlin(
            b_q_weight=data_int32, perm=perm, size_k=in_features, size_n=out_features, num_bits=8
        )

        return cls(data_int32, size=tensor.size(), stride=tensor.stride())

    def unpack(self) -> torch.Tensor:
        """
        Reinterprets the packed tensor (a, b) of type int32 and in the marlin order, to a tensor (a * 4, b) of type float8_e4m3fn, in the natural order.
        """
        float8_data = unpack_int32_to_fp8(self._data)

        # complex indexing is not implemented for 'Float8_e4m3fn'
        uint8_data = float8_data.view(torch.uint8)

        n_rows, n_col = uint8_data.shape

        # swap columns
        column_map = get_column_permutation(n_col=n_col)

        uint8_data = uint8_data.T.contiguous()
        uint8_data = uint8_data[column_map]
        uint8_data = uint8_data.T.contiguous()

        uint8_data = uint8_data.reshape(uint8_data.shape[0] * 4, -1)

        # swap rows
        row_map = get_row_permutation(n_rows=n_rows)

        uint8_data = uint8_data[row_map]

        float8_data = uint8_data.view(torch.float8_e4m3fn)
        float8_data = float8_data.T  # As we originally transposed in `pack_fp8_as_int32`
        return float8_data

    @property
    def dtype(self):
        return torch.int32

    def __tensor_flatten__(self):
        """Return the inner tensor attribute names and the string-encoded size/stride metadata."""
        inner_tensors = ["_data"]
        # Since meta can be used for serialization, use only AST compatible strings
        meta = {
            "size": str(list(self.size())),
            "stride": str(self.stride()),
        }
        return inner_tensors, meta

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        """Rebuild a `MarlinF8PackedTensor` from the output of `__tensor_flatten__`."""
        assert len(inner_tensors) == 1
        assert len(meta) == 2
        data = inner_tensors["_data"]
        # Meta should contain only AST compatible strings
        size = ast.literal_eval(meta["size"])
        stride = ast.literal_eval(meta["stride"])
        return MarlinF8PackedTensor(data, size, stride)

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, op, types, args, kwargs=None):
        """
        Dispatch aten operations on packed tensors.

        `detach` and copies that stay on a CUDA device keep the packed representation;
        a copy to any other device returns the unpacked fp8 tensor moved to that device;
        every other operation is applied to the unpacked tensors.

        Raises:
            ValueError: if a copy requests a dtype other than `torch.int32`.
        """
        # `kwargs` defaults to None: normalize it once so that every branch can treat it as a dict.
        kwargs = kwargs or {}

        # Convert back to tensor before calling any operation except detach and move
        if op.overloadpacket is torch.ops.aten.detach:
            t = args[0]
            data = op(t._data)
            return cls(data, t.size(), t.stride())
        elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to):
            t = args[0]
            dtype = kwargs.get("dtype", torch.int32)
            if dtype != torch.int32:
                raise ValueError(f"MarlinF8PackedTensor are torch.int32 only and cannot be moved to {dtype}.")
            device = kwargs.get("device", t.device)
            if device.type == "cuda":
                # Apply the operation to the packed payload, forcing the payload dtype.
                data_kwargs = copy(kwargs)
                data_kwargs["dtype"] = t._data.dtype
                data = op(t._data, **data_kwargs)
                return cls(data, t.size(), t.stride())
            else:
                return t.unpack().to(device)
        else:
            args, kwargs = pytree.tree_map_only(cls, lambda x: x.unpack(), (args, kwargs))
            return op(*args, **kwargs)