cpp/core/utils/qat/QatCodec.cc (239 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include <arrow/result.h> #include <arrow/util/compression.h> #include <arrow/util/logging.h> #include <qatseqprod.h> #include <qatzip.h> #include <zstd.h> #include <algorithm> #include <mutex> #include "QatCodec.h" #define QZ_INIT_FAIL(rc) ((QZ_OK != (rc)) && (QZ_DUPLICATE != (rc))) #define QZ_SETUP_SESSION_FAIL(rc) (QZ_PARAMS == (rc) || QZ_NOSW_NO_HW == (rc) || QZ_NOSW_LOW_MEM == (rc)) namespace gluten { namespace qat { class QatZipCodec : public arrow::util::Codec { protected: explicit QatZipCodec(int compressionLevel) : compressionLevel_(compressionLevel) {} ~QatZipCodec() { static_cast<void>(qzTeardownSession(&qzSession_)); static_cast<void>(qzClose(&qzSession_)); } arrow::Result<int64_t> Decompress(int64_t inputLen, const uint8_t* input, int64_t outputLen, uint8_t* output) override { uint32_t compressedSize = static_cast<uint32_t>(inputLen); uint32_t uncompressedSize = static_cast<uint32_t>(outputLen); int ret = qzDecompress(&qzSession_, input, &compressedSize, output, &uncompressedSize); if (ret == QZ_OK) { return static_cast<int64_t>(uncompressedSize); } else if (ret == QZ_PARAMS) { return arrow::Status::IOError("QAT decompression failure: params is invalid"); } else if (ret == QZ_FAIL) { return arrow::Status::IOError("QAT decompression failure: Function did not succeed"); } else { return arrow::Status::IOError("QAT decompression failure with error:", ret); } } int64_t MaxCompressedLen(int64_t inputLen, const uint8_t* ARROW_ARG_UNUSED(input)) override { ARROW_DCHECK_GE(inputLen, 0); return qzMaxCompressedLength(static_cast<size_t>(inputLen), &qzSession_); } arrow::Result<int64_t> Compress(int64_t inputLen, const uint8_t* input, int64_t outputLen, uint8_t* output) override { uint32_t uncompressedSize = static_cast<uint32_t>(inputLen); uint32_t compressedSize = static_cast<uint32_t>(outputLen); int ret = qzCompress(&qzSession_, input, &uncompressedSize, output, &compressedSize, 1); if (ret == QZ_OK) { return static_cast<int64_t>(compressedSize); } else if (ret == QZ_PARAMS) { return arrow::Status::IOError("QAT compression failure: params is invalid"); } else if (ret == QZ_FAIL) { return arrow::Status::IOError("QAT compression failure: function did not succeed"); } else { return arrow::Status::IOError("QAT compression failure with error:", ret); } } arrow::Result<std::shared_ptr<arrow::util::Compressor>> MakeCompressor() override { return arrow::Status::NotImplemented("Streaming compression unsupported with QAT"); } arrow::Result<std::shared_ptr<arrow::util::Decompressor>> MakeDecompressor() override { return arrow::Status::NotImplemented("Streaming decompression unsupported with QAT"); } int compression_level() const override { return compressionLevel_; } int compressionLevel_; QzSession_T qzSession_ = {0}; }; class QatGZipCodec final : public QatZipCodec { public: QatGZipCodec(QzPollingMode_T pollingMode, int compressionLevel) : QatZipCodec(compressionLevel) { auto rc = qzInit(&qzSession_, /* sw_backup = */ 1); if (QZ_INIT_FAIL(rc)) { ARROW_LOG(WARNING) << "qzInit failed with error: " << rc; } else { QzSessionParamsDeflate_T params; qzGetDefaultsDeflate(&params); // get the default value. params.common_params.polling_mode = pollingMode; params.common_params.comp_lvl = compressionLevel; rc = qzSetupSessionDeflate(&qzSession_, &params); if (QZ_SETUP_SESSION_FAIL(rc)) { ARROW_LOG(WARNING) << "qzSetupSession failed with error: " << rc; } } } arrow::Compression::type compression_type() const override { return arrow::Compression::GZIP; } int minimum_compression_level() const override { return QZ_DEFLATE_COMP_LVL_MINIMUM; } int maximum_compression_level() const override { return QZ_DEFLATE_COMP_LVL_MAXIMUM; } int default_compression_level() const override { return QZ_COMP_LEVEL_DEFAULT; } }; bool supportsCodec(const std::string& codec) { return !codec.empty() && std::any_of(kQatSupportedCodec.begin(), kQatSupportedCodec.end(), [&](const auto& qatCodec) { return qatCodec == codec; }); } std::unique_ptr<arrow::util::Codec> makeQatGZipCodec(QzPollingMode_T pollingMode, int compressionLevel) { return std::unique_ptr<arrow::util::Codec>(new QatGZipCodec(pollingMode, compressionLevel)); } std::unique_ptr<arrow::util::Codec> makeDefaultQatGZipCodec() { return makeQatGZipCodec(QZ_BUSY_POLLING, QZ_COMP_LEVEL_DEFAULT); } namespace { constexpr int kZSTDDefaultCompressionLevel = 1; arrow::Status ZSTDError(size_t ret, const char* prefix_msg) { return arrow::Status::IOError(prefix_msg, ZSTD_getErrorName(ret)); } void logZstdOnError(size_t ret, const char* prefixMsg) { if (ZSTD_isError(ret)) { ARROW_LOG(WARNING) << prefixMsg << ZSTD_getErrorName(ret); } } } // namespace class QatDevice { public: QatDevice() { if (QZSTD_startQatDevice() == QZSTD_FAIL) { ARROW_LOG(WARNING) << "QZSTD_startQatDevice failed"; } else { initialized_ = true; } } ~QatDevice() { if (initialized_) { QZSTD_stopQatDevice(); } } static std::shared_ptr<QatDevice> getInstance() { std::call_once(initQat_, []() { instance_ = std::make_shared<QatDevice>(); }); return instance_; } bool deviceInitialized() { return initialized_; } private: inline static std::shared_ptr<QatDevice> instance_; inline static std::once_flag initQat_; bool initialized_{false}; }; class QatZstdCodec final : public arrow::util::Codec { public: explicit QatZstdCodec(int compressionLevel) : compressionLevel_(compressionLevel) {} ~QatZstdCodec() { if (initCCtx_) { ZSTD_freeCCtx(zc_); if (sequenceProducerState_) { QZSTD_freeSeqProdState(sequenceProducerState_); } } } arrow::Compression::type compression_type() const override { return arrow::Compression::ZSTD; } arrow::Result<int64_t> Decompress(int64_t inputLen, const uint8_t* input, int64_t outputLen, uint8_t* output) override { if (output == nullptr) { // We may pass a NULL 0-byte output buffer but some zstd versions demand // a valid pointer: https://github.com/facebook/zstd/issues/1385 static uint8_t emptyBuffer; DCHECK_EQ(outputLen, 0); output = &emptyBuffer; } size_t ret = ZSTD_decompress(output, static_cast<size_t>(outputLen), input, static_cast<size_t>(inputLen)); if (ZSTD_isError(ret)) { return ZSTDError(ret, "ZSTD decompression failed: "); } if (static_cast<int64_t>(ret) != outputLen) { return arrow::Status::IOError("Corrupt ZSTD compressed data."); } return static_cast<int64_t>(ret); } int64_t MaxCompressedLen(int64_t inputLen, const uint8_t* ARROW_ARG_UNUSED(input)) override { DCHECK_GE(inputLen, 0); return ZSTD_compressBound(static_cast<size_t>(inputLen)); } arrow::Result<int64_t> Compress(int64_t inputLen, const uint8_t* input, int64_t outputLen, uint8_t* output) override { RETURN_NOT_OK(initCCtx()); size_t ret = ZSTD_compress2(zc_, output, static_cast<size_t>(outputLen), input, static_cast<size_t>(inputLen)); if (ZSTD_isError(ret)) { return ZSTDError(ret, "ZSTD compression failed: "); } return static_cast<int64_t>(ret); } arrow::Result<std::shared_ptr<arrow::util::Compressor>> MakeCompressor() override { return arrow::Status::NotImplemented("Streaming compression unsupported with QAT"); } arrow::Result<std::shared_ptr<arrow::util::Decompressor>> MakeDecompressor() override { return arrow::Status::NotImplemented("Streaming decompression unsupported with QAT"); } int minimum_compression_level() const override { return ZSTD_minCLevel(); } int maximum_compression_level() const override { return ZSTD_maxCLevel(); } int default_compression_level() const override { return kZSTDDefaultCompressionLevel; } int compression_level() const override { return compressionLevel_; } private: int compressionLevel_; ZSTD_CCtx* zc_; bool initCCtx_{false}; std::shared_ptr<QatDevice> qatDevice_; void* sequenceProducerState_{nullptr}; arrow::Status initCCtx() { if (initCCtx_) { return arrow::Status::OK(); } zc_ = ZSTD_createCCtx(); logZstdOnError( ZSTD_CCtx_setParameter(zc_, ZSTD_c_compressionLevel, compressionLevel_), "ZSTD_CCtx_setParameter failed on ZSTD_c_compressionLevel: "); if (!qatDevice_) { qatDevice_ = QatDevice::getInstance(); } if (qatDevice_->deviceInitialized()) { sequenceProducerState_ = QZSTD_createSeqProdState(); /* register qatSequenceProducer */ ZSTD_registerSequenceProducer(zc_, sequenceProducerState_, qatSequenceProducer); /* Enable sequence producer fallback */ logZstdOnError( ZSTD_CCtx_setParameter(zc_, ZSTD_c_enableSeqProducerFallback, 1), "ZSTD_CCtx_setParameter failed on ZSTD_c_enableSeqProducerFallback: "); } initCCtx_ = true; return arrow::Status::OK(); } }; std::unique_ptr<arrow::util::Codec> makeQatZstdCodec(int compressionLevel) { return std::unique_ptr<arrow::util::Codec>(new QatZstdCodec(compressionLevel)); } std::unique_ptr<arrow::util::Codec> makeDefaultQatZstdCodec() { return makeQatZstdCodec(kZSTDDefaultCompressionLevel); } } // namespace qat } // namespace gluten