optimum/quanto/library/qbytes_mm.py

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from packaging import version


__all__ = []


torch.library.define("quanto::qbytes_mm", "(Tensor A, Tensor B, Tensor scales) -> Tensor")


def qbytes_mm(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor:
    activations = activations.to(output_scales.dtype)
    if weights.dtype.is_floating_point:
        # Float8 requires an explicit promotion
        weights = weights.to(output_scales.dtype)
    # Apply the scale to the weights before the matrix multiplication to put them back
    # into their initial numerical range and avoid overflows
    weights = output_scales * weights
    return torch.matmul(activations, weights.t())


def qbytes_int_mm(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor:
    in_features = activations.shape[-1]
    out_features = weights.shape[0]
    # torch._int_mm works on transposed weights, i.e. (in_features, out_features)
    weights = weights.t()
    if activations.ndim == 2:
        out_data = torch._int_mm(activations, weights)
    else:
        output_shape = activations.shape[:-1] + (out_features,)
        out_data = torch._int_mm(activations.reshape(-1, in_features), weights)
        out_data = out_data.reshape(output_shape)
    # We must evaluate the output as float32 because the multiplication
    # of the int32 data by the scales might overflow
    fp32_output = out_data.to(torch.float32) * output_scales.t()
    return fp32_output.to(output_scales.dtype)


def qbytes_int8pack_mm(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor:
    # torch._weight_int8pack_mm expects a vector of scales
    output_scales = output_scales.flatten()
    if activations.ndim == 2:
        return torch._weight_int8pack_mm(activations, weights, output_scales)
    else:
        in_features = activations.shape[-1]
        out_features = weights.shape[0]
        output_shape = activations.shape[:-1] + (out_features,)
        out_data = torch._weight_int8pack_mm(activations.reshape(-1, in_features), weights, output_scales)
        return out_data.reshape(output_shape)


@torch.library.impl("quanto::qbytes_mm", "default")
def qbytes_mm_impl_default(
    activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor
) -> torch.Tensor:
    return qbytes_mm(activations, weights, output_scales)


@torch.library.impl("quanto::qbytes_mm", "CUDA")
def qbytes_mm_impl_cuda(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor:
    assert activations.ndim in (2, 3)
    in_features = activations.shape[-1]
    tokens = activations.shape[0] if activations.ndim == 2 else activations.shape[0] * activations.shape[1]
    out_features = weights.shape[0]
    if (
        activations.dtype == torch.int8
        and weights.dtype == torch.int8
        and tokens > 16
        and tokens % 8 == 0
        and in_features % 8 == 0
        and out_features % 8 == 0
    ):
        return qbytes_int_mm(activations, weights, output_scales)
    return qbytes_mm(activations, weights, output_scales)


@torch.library.impl("quanto::qbytes_mm", "CPU")
def qbytes_mm_impl_cpu(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor:
    if (
        # FIXME: accuracy issues with 2.4.x
        version.parse(torch.__version__).release >= version.parse("2.6.0").release
        and activations.dtype == torch.int8
        and weights.dtype == torch.int8
    ):
        return qbytes_int_mm(activations, weights, output_scales)
    in_features = activations.shape[-1]
    if activations.dtype == torch.bfloat16 and weights.dtype == torch.int8 and in_features % 4 == 0:
        if type(activations) is not torch.Tensor:
            activations = activations.dequantize()
        return qbytes_int8pack_mm(activations, weights, output_scales)
    return qbytes_mm(activations, weights, output_scales)


@torch.library.impl("quanto::qbytes_mm", "MPS")
def qbytes_mm_impl_mps(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor:
    in_features = activations.shape[-1]
    out_features = weights.shape[0]
    if (
        version.parse(torch.__version__).release >= version.parse("2.4.0").release
        and activations.dtype == torch.bfloat16
        and weights.dtype == torch.int8
        and in_features % 32 == 0
        and out_features % 32 == 0
    ):
        if type(activations) is not torch.Tensor:
            activations = activations.dequantize()
        return qbytes_int8pack_mm(activations, weights, output_scales)
    return qbytes_mm(activations, weights, output_scales)
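

# A minimal usage sketch, assuming int8 activations and weights with
# per-output-channel float32 scales of shape (out_features, 1). Once this
# module is imported, `torch.ops.quanto.qbytes_mm` dispatches to the
# implementation registered above for the input tensors' device. The shapes
# and values below are illustrative assumptions, chosen so the int8 fast
# paths' divisibility constraints are satisfied.
if __name__ == "__main__":
    activations = torch.randint(-128, 127, (32, 64), dtype=torch.int8)  # (tokens, in_features)
    weights = torch.randint(-128, 127, (128, 64), dtype=torch.int8)  # (out_features, in_features)
    output_scales = torch.rand(128, 1, dtype=torch.float32)  # per-output-channel scales
    out = torch.ops.quanto.qbytes_mm(activations, weights, output_scales)
    print(out.shape, out.dtype)  # torch.Size([32, 128]) torch.float32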