optimum/quanto/tensor/weights/marlin/int4/packed.py (113 lines of code) (raw):

import ast from copy import copy import numpy as np import torch from torch.utils import _pytree as pytree from ...packing import unpack_int32_to_uint8 from ...reordering import reorder, reverse __all__ = ["MarlinInt4PackedTensor"] # From: https://github.com/IST-DASLab/marlin/blob/master/marlin/__init__.py#L40 # this func does 2 things # 1. 1 thread can load 32 4bit == 128bit weights used for mulitple mma instructions at once # 2. faster dequant via parallel half2 mul def _get_perm(): perm = [] # 32 == # of threads in 1 warp for i in range(32): perm1 = [] # column id in 16x8 weight block # check https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float col = i // 4 # 1 32bit (int32) == 8 4bit, 1 thread has 4 weights per 16x8 & 4bit weights are packed in int32, so needs 2 16x8 == 1 16x16 blocks for block in [0, 1]: # row id in 16x8 weight block # check https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float for row in [ 2 * (i % 4), 2 * (i % 4) + 1, 2 * (i % 4 + 4), 2 * (i % 4 + 4) + 1, ]: # 8 weights used for 1 thread (16x16 block) are contiguous in memory via interleaving # e.g. T0 uses (0, 16, 128, 144, 8, 24, 136, 152) perm1.append(16 * row + col + 8 * block) # 1 128bit (int4) == 4 32bit, 1 thread loads 128bit at once, so needs 4 16x16 == 1 16x64 blocks for j in range(4): # 32 weights loaded by 1 thread (16x64 block) are contiguous in memory via interleaving # e.g. T0 uses ((0 ~ 152) + 0 * 256, (0 ~ 152) + 1 * 256, ..., (0 ~ 152) + 3 * 256) perm.extend([p + 256 * j for p in perm1]) perm = np.array(perm) # for faster dequant # check https://arxiv.org/pdf/2211.10017 interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) perm = perm.reshape((-1, 8))[:, interleave].ravel() perm = torch.from_numpy(perm) return perm _perm = _get_perm() _rev_perm = reverse(_perm) # From: https://github.com/IST-DASLab/marlin/blob/master/marlin/__init__.py#L102 def pack(unpacked: torch.Tensor): w = unpacked N, K = w.shape w = unpacked.t() # 16 == tile size, marlin uses 16x16 tile, so 16x16 grouping via interleaving w = w.reshape((K // 16, 16, N // 16, 16)) w = w.permute((0, 2, 1, 3)) w = w.reshape((K // 16, N * 16)) res = w # _perm.numel() == 1024 == 4 16x16, permute weights with 4 16x16 unit for efficient mma + dequant res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape) p = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32) res = res.cpu().numpy().astype(np.uint32) for i in range(8): p |= res[:, i::8] << 4 * i p = torch.from_numpy(p.astype(np.int32)).to(w.device) return p def unpack(packed, orig_shape): N, K = orig_shape # Unpack to recover individual values unpacked = unpack_int32_to_uint8(packed, bits=4).to(torch.uint8) # Recover the original ordering unpacked = reorder(unpacked, _rev_perm) # Apply block permutations in the reverse order unpacked = unpacked.reshape(K // 16, N // 16, 16, 16) unpacked = unpacked.permute((0, 2, 1, 3)) unpacked = unpacked.reshape(K, N) return unpacked.t() class MarlinInt4PackedTensor(torch.Tensor): @staticmethod def __new__(cls, data, size, stride, requires_grad=False): assert data.device.type == "cuda" assert data.dtype == torch.int32 assert requires_grad is False return torch.Tensor._make_wrapper_subclass( cls, size, strides=stride, dtype=torch.uint8, device=data.device, requires_grad=requires_grad ) def __init__(self, data, size, stride, requires_grad=False): self._data = data def __repr__(self): return f"MarlinInt4PackedTensor({self._data})" @classmethod def pack(cls, t): data = pack(t) return MarlinInt4PackedTensor(data, t.size(), t.stride()) def unpack(self): return unpack(self._data, self.size()) @property def dtype(self): return torch.uint8 def __tensor_flatten__(self): inner_tensors = ["_data"] meta = { "size": str(list(self.size())), "stride": str(self.stride()), } return inner_tensors, meta @staticmethod def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): assert len(inner_tensors) == 1 assert len(meta) == 2 data = inner_tensors["_data"] size = ast.literal_eval(meta["size"]) stride = ast.literal_eval(meta["stride"]) return MarlinInt4PackedTensor(data, size, stride) __torch_function__ = torch._C._disabled_torch_function_impl @classmethod def __torch_dispatch__(cls, op, types, args, kwargs=None): if op.overloadpacket is torch.ops.aten.detach: t = args[0] data = op(t._data) return MarlinInt4PackedTensor(data, t.size(), t.stride()) elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to): t = args[0] dtype = kwargs.get("dtype", torch.uint8) if dtype != torch.uint8: raise ValueError(f"MarlinInt4PackedTensor are torch.uint8 only and cannot be moved to {dtype}.") device = kwargs.get("device", t.device) if device.type == "cuda": data_kwargs = copy(kwargs) data_kwargs["dtype"] = t._data.dtype data = op(t._data, **data_kwargs) return MarlinInt4PackedTensor(data, t.size(), t.stride()) return t.unpack() args, kwargs = pytree.tree_map_only(MarlinInt4PackedTensor, lambda x: x.unpack(), (args, kwargs or {})) return op(*args, **kwargs) def numpy(self): return self.unpack().cpu().numpy()