cdk/protocol/mysqlx/protocol_compression.cc (291 lines of code) (raw):

/* * Copyright (c) 2015, 2024, Oracle and/or its affiliates. * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License, version 2.0, as * published by the Free Software Foundation. * * This program is designed to work with certain software (including * but not limited to OpenSSL) that is licensed under separate terms, as * designated in a particular file or component or in included license * documentation. The authors of MySQL hereby grant you an additional * permission to link the program and your derivative works with the * separately licensed software that they have either included with * the program or referenced in the documentation. * * Without limiting anything contained in the foregoing, this file, * which is part of Connector/C++, is also subject to the * Universal FOSS Exception, version 1.0, a copy of which can be found at * https://oss.oracle.com/licenses/universal-foss-exception. * * This program is distributed in the hope that it will be useful, but * WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. * See the GNU General Public License, version 2.0, for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software Foundation, Inc., * 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA */ /* Implementation of mysqlx protocol compression ============================================= */ // Note: on Windows this includes windows.h #include <mysql/cdk/foundation/common.h> #include "protocol.h" PUSH_SYS_WARNINGS_CDK #include <memory.h> // for memcpy POP_SYS_WARNINGS_CDK #include <sstream> using namespace cdk::foundation; using namespace cdk::protocol::mysqlx; namespace cdk { namespace protocol { namespace mysqlx { /* ZLib Compression Algorithm functions ==================================== */ void Compression_zlib::init() { if (m_zlib_inited) return; // Initial functions mapping, keep the internal implementation m_c_zstream.zalloc = Z_NULL; m_c_zstream.zfree = Z_NULL; m_c_zstream.opaque = Z_NULL; m_c_zstream.total_out = 0; // TODO: Make the compression level adjustable if (deflateInit(&m_c_zstream, 9) != Z_OK) throw_error("Could not initialize compression output stream"); // Initial functions mapping, keep the internal implementation m_u_zstream.zalloc = Z_NULL; m_u_zstream.zfree = Z_NULL; m_u_zstream.opaque = Z_NULL; if (inflateInit(&m_u_zstream) != Z_OK) throw_error("Could not initialize compression input stream"); m_zlib_inited = true; } size_t Compression_zlib::compress(byte *src, size_t len) { size_t total_compressed_len = m_c_zstream.total_out; m_c_zstream.next_in = src; // Input buffer with uncompressed data m_c_zstream.avail_in = (uInt)len; // Length of uncompressed data /* TODO: Do smarter allocation for compression buffer since the upper bound might be quite redundant. */ size_t deflate_size = deflateBound(&m_c_zstream, (uLong)len); // This will reallocate the buffer if needed and get its address m_c_zstream.next_out = m_protocol_compression.get_out_buf(deflate_size); m_c_zstream.avail_out = (uInt)m_protocol_compression.get_out_buf_len(); int res = deflate(&m_c_zstream, Z_SYNC_FLUSH); if (res != Z_OK) return 0; return m_c_zstream.total_out - total_compressed_len; } size_t Compression_zlib::uncompress(byte *dst, size_t dest_size, size_t compressed_size, size_t &bytes_consumed) { m_u_zstream.next_in = m_protocol_compression.get_inp_buf(); m_u_zstream.avail_in = (uInt)compressed_size; m_u_zstream.next_out = dst; m_u_zstream.avail_out = (uInt)dest_size; int inflate_res = inflate(&m_u_zstream, Z_SYNC_FLUSH); if (inflate_res != Z_OK) { inflateReset(&m_u_zstream); return COMPRESSION_ERROR; } // The number of processed compressed bytes bytes_consumed = compressed_size - m_u_zstream.avail_in; // The number of uncompressed bytes return dest_size - m_u_zstream.avail_out; } Compression_zlib::~Compression_zlib() NOEXCEPT { if (m_zlib_inited) { deflateEnd(&m_c_zstream); inflateEnd(&m_u_zstream); } } /* LZ4 Compression Algorithm functions =================================== */ void Compression_lz4::init() { if (m_dctx && m_cctx) return; if (m_dctx == nullptr && LZ4F_isError(LZ4F_createDecompressionContext(&m_dctx, LZ4F_getVersion()))) throw_error("Error creating LZ4 decompression context"); if (m_cctx == nullptr && LZ4F_isError(LZ4F_createCompressionContext(&m_cctx, LZ4F_getVersion()))) throw_error("Error creating LZ4 compression context"); m_lz4f_pref.autoFlush = 1; m_lz4f_pref.frameInfo.contentSize = 0; } size_t Compression_lz4::compress(byte *src, size_t len) { auto check_lz4_result = [this](size_t result) { if (LZ4F_isError(result)) { LZ4F_freeCompressionContext(m_cctx); m_cctx = nullptr; throw_error(string{"LZ4: "} + LZ4F_getErrorName(result)); } }; if (len > LZ4_MAX_INPUT_SIZE) throw_error("Data for compression is too long"); // Header size + The worst case compressed data size_t wbuf_size = LZ4F_HEADER_SIZE_MAX + LZ4F_compressBound(len, &m_lz4f_pref); // Allocate wr buf and adjust the offset for writing data byte *dest_buf_adjusted = m_protocol_compression.get_out_buf(wbuf_size); // Update with the real buffer length wbuf_size = m_protocol_compression.get_out_buf_len(); size_t begin_result = LZ4F_compressBegin(m_cctx, (void*)dest_buf_adjusted, wbuf_size, &m_lz4f_pref); check_lz4_result(begin_result); dest_buf_adjusted += begin_result; wbuf_size -= begin_result; void *src_adjusted = (void*)src; size_t compression_result = LZ4F_compressUpdate(m_cctx, (void*)dest_buf_adjusted, wbuf_size, src_adjusted, len, nullptr); check_lz4_result(compression_result); dest_buf_adjusted += compression_result; wbuf_size -= compression_result; assert(4 <= wbuf_size); size_t flush_result = LZ4F_compressEnd(m_cctx, (void*)dest_buf_adjusted, wbuf_size, nullptr); check_lz4_result(flush_result); return begin_result + flush_result + compression_result; } size_t Compression_lz4::uncompress(byte *dst, size_t dest_size, size_t compressed_size, size_t &bytes_consumed) { size_t bytes_processed = 0; size_t initial_dest_size = dest_size; while (true) { size_t bytes_to_write = dest_size; size_t current_bytes_processed = compressed_size - bytes_processed; size_t result = LZ4F_decompress(m_dctx, (void*)dst, &bytes_to_write, (void*)(m_protocol_compression.get_inp_buf() + bytes_processed), &current_bytes_processed, nullptr); if (LZ4F_isError(result)) { LZ4F_resetDecompressionContext(m_dctx); throw_error("Problem during LZ4 decompression"); } if (dest_size < bytes_to_write) { throw_error("Decompression buffer is not large enough"); } bytes_processed += current_bytes_processed; dst += bytes_to_write; // Adjust the buffer writing position dest_size -= bytes_to_write; // Adjust buffer size awailable for writing if (result == 0 || current_bytes_processed == 0 /*bytes_to_write == 0*/) break; } bytes_consumed = bytes_processed; return initial_dest_size - dest_size; } Compression_lz4::~Compression_lz4() NOEXCEPT { if (m_dctx) LZ4F_freeDecompressionContext(m_dctx); if (m_cctx) LZ4F_freeCompressionContext(m_cctx); } /* ZStd Compression Algorithm ========================== */ void Compression_zstd::init() { if (m_c_zstd && m_u_zstd) return; if (m_c_zstd == nullptr) { m_c_zstd = ZSTD_createCStream(); if (ZSTD_isError(ZSTD_initCStream(m_c_zstd, -1))) throw_error("Error creating ZSTD compression stream"); } if (m_u_zstd == nullptr) { m_u_zstd = ZSTD_createDStream(); if (ZSTD_isError(ZSTD_initDStream(m_u_zstd))) throw_error("Error creating ZSTD decompression stream"); } } size_t Compression_zstd::compress(byte *src, size_t len) { size_t estimated_c_size = ZSTD_compressBound(len); ZSTD_outBuffer out_buffer{ m_protocol_compression.get_out_buf(estimated_c_size), estimated_c_size, 0 }; ZSTD_inBuffer in_buffer{ src, len, 0 }; while (in_buffer.pos < in_buffer.size) { size_t result = ZSTD_compressStream(m_c_zstd, &out_buffer, &in_buffer); if (ZSTD_isError(result)) throw_error("ZSTD compression error"); } size_t flush_result = ZSTD_flushStream(m_c_zstd, &out_buffer); if (ZSTD_isError(flush_result)) throw_error("ZSTD flush error"); return out_buffer.pos; } size_t Compression_zstd::uncompress(byte *dst, size_t dest_size, size_t compressed_size, size_t &bytes_consumed) { ZSTD_outBuffer out_buffer{ dst, dest_size, 0 }; ZSTD_inBuffer in_buffer{ m_protocol_compression.get_inp_buf(), compressed_size, 0 }; while (out_buffer.pos < out_buffer.size) { size_t result = ZSTD_decompressStream(m_u_zstd, &out_buffer, &in_buffer); if (ZSTD_isError(result)) throw_error("ZSTD decompression error"); // All input is consumed, do not attempt to go for another iteration if (in_buffer.pos >= in_buffer.size) break; } bytes_consumed = in_buffer.pos; return out_buffer.pos; } Compression_zstd::~Compression_zstd() NOEXCEPT { if (m_u_zstd) ZSTD_freeDStream(m_u_zstd); if (m_c_zstd) ZSTD_freeCStream(m_c_zstd); } /* Protocol compression ==================== */ Protocol_compression::Protocol_compression() { } byte* Protocol_compression::get_out_buf(size_t size) { if (m_c_out_size && size <= m_c_out_size) return m_c_out_buf; byte *tmp = (byte*)realloc(m_c_out_buf, size); if (!tmp) throw_error("Could not reallocate compression output buffer"); m_c_out_buf = tmp; m_c_out_size = size; return m_c_out_buf; } size_t Protocol_compression::do_compress(byte *src, size_t len) { if (!m_algorithm) throw_error("Unknown compression type"); return m_algorithm->compress(src, len); } bool Protocol_compression::uncompress(byte *buf, size_t size) { // If no more data is needed do not uncompress anything if (0 == size) return true; size_t orig_size = size; do { size -= do_uncompress(buf + orig_size - size, size); if (0 == size) return true; if (COMPRESSION_ERROR == size) return false; }while(size); return true; } size_t Protocol_compression::do_uncompress(byte *dst, size_t dest_size) { size_t bytes_uncompressed = 0; size_t bytes_consumed = 0; /* ZSTD can consume the entire input when uncompressing 5 bytes of header and we need to call the uncompression again to obtain the rest of uncompressed data. */ if (m_c_inp_size || m_u_total_size) { if (!m_algorithm) throw_error("Unknown compression type"); bytes_uncompressed = m_algorithm->uncompress (dst, dest_size, m_c_inp_size, bytes_consumed); m_c_inp_offset += bytes_consumed; m_c_inp_size -= bytes_consumed; m_u_total_size -= bytes_uncompressed; } return bytes_uncompressed; } void Protocol_compression::set_compression_type (Compression_type::value compression_type) { m_compression_type = compression_type; switch (m_compression_type) { case Compression_type::DEFLATE: m_algorithm.reset(new Compression_zlib(*this)); break; case Compression_type::LZ4: m_algorithm.reset(new Compression_lz4(*this)); break; case Compression_type::ZSTD: m_algorithm.reset(new Compression_zstd(*this)); break; case Compression_type::NONE: m_algorithm.reset(); break; default: throw_error("Unknown compression type"); } } Protocol_compression::~Protocol_compression() NOEXCEPT { if (m_c_out_buf) free(m_c_out_buf); } }}}