optimum/quanto/tensor/weights/marlin/fp8/qbits.py

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast

import torch

from ....function import QuantizedLinearFunction
from ....qtype import qfloat8_e4m3fn, qtypes
from ...qbytes import WeightQBytesTensor
from .packed import MarlinF8PackedTensor, get_scale_perms


__all__ = ["MarlinF8QBytesTensor"]


class MarlinF8QBytesLinearFunction(QuantizedLinearFunction):
    @staticmethod
    def forward(ctx, input, other, bias=None):
        ctx.save_for_backward(input, other)
        input_shape = input.shape

        if input.ndim > 2:
            input = input.reshape(-1, input_shape[-1])

        output = torch.ops.quanto.gemm_f16f8_marlin(
            input,
            b_q_weight=other._data._data,
            b_scales=other._scale,  # .to(input.dtype)
            workspace=other._workspace,
            num_bits=8,
            size_m=input.shape[0],
            size_n=other._scale.shape[1],
            size_k=input.shape[1],
        )

        if len(input_shape) > 2:
            output = output.reshape(input_shape[:-1] + (other._scale.shape[1],))

        return output


class MarlinF8QBytesTensor(WeightQBytesTensor):
    @staticmethod
    def __new__(cls, qtype, axis, size, stride, data, scale, requires_grad=False):
        assert data.device.type == "cuda"
        assert data.device == scale.device
        return torch.Tensor._make_wrapper_subclass(
            cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad
        )

    def __init__(self, qtype, axis, size, stride, data, scale, requires_grad=False):
        assert axis == 0
        assert data.ndim == 2

        out_features = size[0]
        self._workspace = torch.zeros(out_features // 64 * 16, dtype=torch.int, device=data.device)

        # TODO: Here we should use `not isinstance(data, MarlinF8PackedTensor)`, but `torch.compile` is bugged
        # when using that. Somewhere in the internals of torch.compile, `data` gets converted to a
        # `torch._subclasses.fake_tensor.FakeTensor` that does not inherit from `MarlinF8PackedTensor`,
        # and torch then goes into the wrong control flow.
        # Reference: https://pytorch.slack.com/archives/C033H6DJSJU/p1721837684035049
        if data.dtype != torch.int32:
            assert scale.shape == (out_features, 1)

            scale_perm_single = get_scale_perms()
            scale = scale.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
            scale = scale.reshape(-1, out_features).contiguous()

            # Pack the fp8 data to int32 and apply the Marlin re-ordering.
            data_packed = MarlinF8PackedTensor.pack(data)
        else:
            # When freezing (`model.freeze()`), the data is already a MarlinF8PackedTensor and the scale is already repacked.
            data_packed = data

        super().__init__(
            qtype, axis, size, stride, data_packed, scale, activation_qtype=qfloat8_e4m3fn, requires_grad=requires_grad
        )

    def dequantize(self):
        float8_data = self._data.unpack()

        scale_perm_single = get_scale_perms()

        # `scale_perm_single` holds the mapping from natural to Marlin ordering, so invert it here.
        scale_perm_single_rev = torch.empty_like(scale_perm_single)
        scale_perm_single_rev[scale_perm_single] = torch.arange(len(scale_perm_single))

        scale_reordered = self._scale.reshape((-1, len(scale_perm_single_rev)))[:, scale_perm_single_rev]
        scale_reordered = scale_reordered.reshape(-1, self._scale.shape[1]).contiguous()

        return float8_data.to(scale_reordered.dtype) * scale_reordered.T

    def __repr__(self):
        return f"MarlinF8QBytesTensor({self._data}, scale={self._scale}, dtype={self.dtype})"

    def weight_qbytes_tensor(self):
        data = self._data.unpack()

        scale_perm_single = get_scale_perms()

        # `scale_perm_single` holds the mapping from natural to Marlin ordering, so invert it here.
        scale_perm_single_rev = torch.empty_like(scale_perm_single)
        scale_perm_single_rev[scale_perm_single] = torch.arange(len(scale_perm_single))

        scale_reordered = self._scale.reshape((-1, len(scale_perm_single_rev)))[:, scale_perm_single_rev]
        scale_reordered = scale_reordered.reshape(-1, self._scale.shape[1]).t().contiguous()

        return WeightQBytesTensor(
            self._qtype, self._axis, self.size(), self.stride(), data, scale_reordered, self.activation_qtype
        )

    def __tensor_flatten__(self):
        inner_tensors = ["_data", "_scale"]
        meta = {
            "qtype": self._qtype.name,
            "axis": str(self._axis),
            "size": str(list(self.size())),
            "stride": str(list(self.stride())),
        }
        return inner_tensors, meta

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert len(inner_tensors) == 2
        assert len(meta) == 4
        data, scale = inner_tensors["_data"], inner_tensors["_scale"]
        # Meta should only contain strings, AST compatible except qtype
        qtype = qtypes[meta["qtype"]]
        axis = ast.literal_eval(meta["axis"])
        size = ast.literal_eval(meta["size"])
        stride = ast.literal_eval(meta["stride"])
        return MarlinF8QBytesTensor(qtype, axis, size, stride, data, scale)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch torch functions applied on this subtensor

        This method is called whenever a torch function (such as `torch.nn.functional.linear`)
        is called with at least one parameter corresponding to this subtensor:

        - if a quantized implementation exists for the selected function, it is called,
        - otherwise, the original implementation is called, deactivating further functional dispatch.

        During the execution of the standard torch function, a second level of dispatch will
        happen, but this time directly on individual torch Tensor operations (mainly ATEN).
        """
        kwargs = kwargs or {}
        if func is torch.nn.functional.linear:

            def qlinear(input, other, bias=None):
                return MarlinF8QBytesLinearFunction.apply(input, other, bias)

            return qlinear(*args, **kwargs)
        elif func is torch.equal:
            input, other = args
            return input.equal(other)
        # Defer to operations dispatcher
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)
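# A minimal, hypothetical sketch (not part of the upstream module, kept under a
# __main__ guard so importing the file is unaffected): it illustrates the
# inverse-permutation trick used in `dequantize` and `weight_qbytes_tensor`
# above, with `perm` standing in for the 1-D index tensor assumed to be
# returned by `get_scale_perms()`.
if __name__ == "__main__":
    # Build the inverse of a permutation by scattering arange through it.
    perm = torch.randperm(64)
    perm_rev = torch.empty_like(perm)
    perm_rev[perm] = torch.arange(len(perm))

    # Applying the forward permutation and then its inverse restores the input order.
    x = torch.randn(64)
    assert torch.equal(x[perm][perm_rev], x)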