/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include "shuffle/Options.h" #include "shuffle/Utils.h" namespace gluten { class Payload { public: enum Type : uint8_t { kCompressed = 1, kUncompressed = 2, kToBeCompressed = 3 }; Payload(Type type, uint32_t numRows, const std::vector* isValidityBuffer); virtual ~Payload() = default; virtual arrow::Status serialize(arrow::io::OutputStream* outputStream) = 0; virtual arrow::Result> readBufferAt(uint32_t index) = 0; int64_t getCompressTime() const { return compressTime_; } int64_t getWriteTime() const { return writeTime_; } Type type() const { return type_; } uint32_t numRows() const { return numRows_; } uint32_t numBuffers() { return isValidityBuffer_->size(); } const std::vector* isValidityBuffer() const { return isValidityBuffer_; } std::string toString() const; protected: Type type_; uint32_t numRows_; const std::vector* isValidityBuffer_; int64_t compressTime_{0}; int64_t writeTime_{0}; }; // A block represents data to be cached in-memory. // Can be compressed or uncompressed. class BlockPayload : public Payload { public: static arrow::Result> fromBuffers( Payload::Type payloadType, uint32_t numRows, std::vector> buffers, const std::vector* isValidityBuffer, arrow::MemoryPool* pool, arrow::util::Codec* codec); static arrow::Result>> deserialize( arrow::io::InputStream* inputStream, const std::shared_ptr& schema, const std::shared_ptr& codec, arrow::MemoryPool* pool, uint32_t& numRows, int64_t& decompressTime); arrow::Status serialize(arrow::io::OutputStream* outputStream) override; arrow::Result> readBufferAt(uint32_t pos) override; protected: BlockPayload( Type type, uint32_t numRows, std::vector> buffers, const std::vector* isValidityBuffer, arrow::MemoryPool* pool, arrow::util::Codec* codec) : Payload(type, numRows, isValidityBuffer), buffers_(std::move(buffers)), pool_(pool), codec_(codec) {} void setCompressionTime(int64_t compressionTime); std::vector> buffers_; arrow::MemoryPool* pool_; arrow::util::Codec* codec_; }; class InMemoryPayload final : public Payload { public: InMemoryPayload( uint32_t numRows, const std::vector* isValidityBuffer, std::vector> buffers) : Payload(Type::kUncompressed, numRows, isValidityBuffer), buffers_(std::move(buffers)) {} static arrow::Result> merge(std::unique_ptr source, std::unique_ptr append, arrow::MemoryPool* pool); arrow::Status serialize(arrow::io::OutputStream* outputStream) override; arrow::Result> readBufferAt(uint32_t index) override; arrow::Result> toBlockPayload(Payload::Type payloadType, arrow::MemoryPool* pool, arrow::util::Codec* codec); int64_t getBufferSize() const; arrow::Status copyBuffers(arrow::MemoryPool* pool); private: std::vector> buffers_; }; class UncompressedDiskBlockPayload : public Payload { public: UncompressedDiskBlockPayload( Type type, uint32_t numRows, const std::vector* isValidityBuffer, arrow::io::InputStream*& inputStream, uint64_t rawSize, arrow::MemoryPool* pool, arrow::util::Codec* codec); arrow::Result> readBufferAt(uint32_t index) override; arrow::Status serialize(arrow::io::OutputStream* outputStream) override; private: arrow::io::InputStream*& inputStream_; uint64_t rawSize_; arrow::MemoryPool* pool_; arrow::util::Codec* codec_; uint32_t readPos_{0}; arrow::Result> readUncompressedBuffer(); }; class CompressedDiskBlockPayload : public Payload { public: CompressedDiskBlockPayload( uint32_t numRows, const std::vector* isValidityBuffer, arrow::io::InputStream*& inputStream, uint64_t rawSize, arrow::MemoryPool* pool); arrow::Status serialize(arrow::io::OutputStream* outputStream) override; arrow::Result> readBufferAt(uint32_t index) override; private: arrow::io::InputStream*& inputStream_; uint64_t rawSize_; arrow::MemoryPool* pool_; }; } // namespace gluten