maga_transformer/cpp/utils/quantization.h (215 lines of code) (raw):

/* * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include "maga_transformer/cpp/utils/QuantInfo.h" #include "stdlib.h" #include <cstdint> #include <string> namespace tensorrt_llm { namespace common { class QuantAlgo{ public: QuantAlgo() = default; QuantAlgo(rtp_llm::QuantAlgo const& quant_algo) : weight_bits_(quant_algo.getWeightBits()) , group_size_(quant_algo.getGroupSize()) , weight_only_(quant_algo.isWeightOnlyPerCol() || quant_algo.isGptq() || quant_algo.isAwq()) , sq_int8_(quant_algo.isSmoothQuant() || quant_algo.isOmniQuant()) , fp8_(quant_algo.isFp8()) {} QuantAlgo(int weight_bits, int64_t group_size, bool weight_only, bool sq_int8, bool fp8) : weight_bits_(weight_bits) , group_size_(group_size) , weight_only_(weight_only) , sq_int8_(sq_int8) , fp8_(fp8) {} int getWeightBits() const { return weight_bits_; } int getGroupSize() const { return group_size_; } bool weightOnly() const { return weight_only_; } bool smoothQuantInt8() const { return sq_int8_; } bool fp8() const { return fp8_; } private: int weight_bits_ = 0; int group_size_ = 0; bool weight_only_ = false; bool sq_int8_ = false; bool fp8_ = false; }; class QuantMode { // [WARNING] KEEP BELOW DEFINITION IN SYNC WITH tensorrt_llm/quantization/mode.py public: using BaseType = std::uint32_t; explicit constexpr QuantMode(BaseType value) noexcept : mValue{value} { } QuantMode() noexcept = default; constexpr QuantMode(QuantMode const&) noexcept = default; constexpr QuantMode& operator=(const QuantMode& other) noexcept = default; static constexpr QuantMode none() noexcept { return QuantMode(BaseType(0)); } static constexpr QuantMode int4Weights() noexcept { return QuantMode(BaseType(1u) << 0); } static constexpr QuantMode int8Weights() noexcept { return QuantMode(BaseType(1u) << 1); } static constexpr QuantMode activations() noexcept { return QuantMode(BaseType(1u) << 2); } static constexpr QuantMode perChannelScaling() noexcept { return QuantMode(BaseType(1u) << 3); } static constexpr QuantMode perTokenScaling() noexcept { return QuantMode(BaseType(1u) << 4); } static constexpr QuantMode perGroupScaling() noexcept { return QuantMode(BaseType(1u) << 5); } static constexpr QuantMode int8KvCache() noexcept { return QuantMode(BaseType(1u) << 6); } static constexpr QuantMode fp8KvCache() noexcept { return QuantMode(BaseType(1u) << 7); } static constexpr QuantMode fp8Qdq() noexcept { return QuantMode(BaseType(1u) << 8); } constexpr BaseType value() const noexcept { return mValue; } constexpr bool isSet(QuantMode const& mode) const noexcept { return (mValue & mode.value()) == mode.value(); } constexpr bool hasInt4Weights() const noexcept { return isSet(int4Weights()); } constexpr bool hasInt8Weights() const noexcept { return isSet(int8Weights()); } constexpr bool hasActivations() const noexcept { return isSet(activations()); } constexpr bool hasPerChannelScaling() const noexcept { return isSet(perChannelScaling()); } constexpr bool hasPerTokenScaling() const noexcept { return isSet(perTokenScaling()); } constexpr bool hasPerGroupScaling() const noexcept { return isSet(perGroupScaling()); } constexpr bool hasStaticActivationScaling() const noexcept { return !hasPerTokenScaling(); } constexpr bool hasInt8KvCache() const noexcept { return isSet(int8KvCache()); } constexpr bool hasFp8KvCache() const noexcept { return isSet(fp8KvCache()); } constexpr bool hasFp8Qdq() const noexcept { return isSet(fp8Qdq()); } constexpr bool hasKvCacheQuant() const noexcept { return hasInt8KvCache() || hasFp8KvCache(); } static constexpr QuantMode fromDescription(bool quantizeWeights = false, bool quantizeActivations = false, bool perToken = false, bool perChannel = false, bool useInt4Weights = false, bool useInt8KvCache = false, bool useFp8KvCache = false, bool useFp8Qdq = false) { QuantMode quantMode{}; if (quantizeWeights) { if (useInt4Weights) quantMode += int4Weights(); else quantMode += int8Weights(); } if (quantizeActivations) { quantMode += activations(); } if (perChannel) { quantMode += QuantMode::perChannelScaling(); } if (perToken) { quantMode += QuantMode::perTokenScaling(); } if (useInt8KvCache) { quantMode += int8KvCache(); } if (useFp8KvCache) { quantMode += fp8KvCache(); } if (useFp8Qdq) { quantMode += fp8Qdq(); } return quantMode; } constexpr QuantMode operator+(const QuantMode& other) const noexcept { return QuantMode(mValue | other.mValue); } constexpr QuantMode& operator+=(const QuantMode& other) noexcept { return *this = *this + other; } constexpr QuantMode operator-(const QuantMode& other) const noexcept { return QuantMode(mValue & ~other.mValue); } constexpr QuantMode& operator-=(const QuantMode& other) noexcept { return *this = *this - other; } constexpr bool operator==(const QuantMode& other) const noexcept { return mValue == other.mValue; } constexpr bool operator!=(const QuantMode& other) const noexcept { return !(*this == other); } private: BaseType mValue{0}; }; } // namespace common }