cpp/core/memory/ColumnarBatch.cc (164 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include <memory> #include "arrow/c/bridge.h" #include "arrow/c/helpers.h" #include "arrow/record_batch.h" #include "memory/MemoryManager.h" #include "operators/writer/ArrowWriter.h" #include "utils/ArrowStatus.h" #include "utils/exception.h" #include "ColumnarBatch.h" namespace gluten { ColumnarBatch::ColumnarBatch(int32_t numColumns, int32_t numRows) : numColumns_(numColumns), numRows_(numRows), exportNanos_(0) {} int32_t ColumnarBatch::numColumns() const { return numColumns_; } int32_t ColumnarBatch::numRows() const { return numRows_; } int64_t ColumnarBatch::getExportNanos() const { return exportNanos_; } std::pair<char*, int> ColumnarBatch::getRowBytes(int32_t rowId) const { throw gluten::GlutenException("Not implemented getRowBytes for ColumnarBatch"); } std::ostream& operator<<(std::ostream& os, const ColumnarBatch& columnarBatch) { return os << "NumColumns: " << std::to_string(columnarBatch.numColumns()) << "NumRows: " << std::to_string(columnarBatch.numRows()); } std::shared_ptr<ColumnarBatch> createZeroColumnBatch(int32_t numRows) { return std::make_shared<ArrowColumnarBatch>(arrow::RecordBatch::Make( std::make_shared<arrow::Schema>(std::vector<std::shared_ptr<arrow::Field>>()), numRows, std::vector<std::shared_ptr<arrow::Array>>())); } ArrowColumnarBatch::ArrowColumnarBatch(std::shared_ptr<arrow::RecordBatch> batch) : ColumnarBatch(batch->num_columns(), batch->num_rows()), batch_(std::move(batch)) {} std::string ArrowColumnarBatch::getType() const { return "arrow"; } int64_t ArrowColumnarBatch::numBytes() { throw gluten::GlutenException("Not implemented GetBytes for ArrowColumnarBatch"); } arrow::RecordBatch* ArrowColumnarBatch::getRecordBatch() const { return batch_.get(); } std::shared_ptr<ArrowSchema> ArrowColumnarBatch::exportArrowSchema() { auto cSchema = std::make_shared<ArrowSchema>(); GLUTEN_THROW_NOT_OK(arrow::ExportSchema(*batch_->schema(), cSchema.get())); return cSchema; } std::shared_ptr<ArrowArray> ArrowColumnarBatch::exportArrowArray() { auto cArray = std::make_shared<ArrowArray>(); GLUTEN_THROW_NOT_OK(arrow::ExportRecordBatch(*batch_, cArray.get())); return cArray; } std::pair<char*, int> ArrowColumnarBatch::getRowBytes(int32_t rowId) const { throw gluten::GlutenException("Not implemented getRowBytes for ArrowColumnarBatch"); } ArrowCStructColumnarBatch::ArrowCStructColumnarBatch( std::unique_ptr<ArrowSchema> cSchema, std::unique_ptr<ArrowArray> cArray) : ColumnarBatch(cArray->n_children, cArray->length) { ArrowSchemaMove(cSchema.get(), cSchema_.get()); ArrowArrayMove(cArray.get(), cArray_.get()); } ArrowCStructColumnarBatch::~ArrowCStructColumnarBatch() { ArrowSchemaRelease(cSchema_.get()); ArrowArrayRelease(cArray_.get()); } std::string ArrowCStructColumnarBatch::getType() const { return "arrow_array"; } int64_t ArrowCStructColumnarBatch::numBytes() { int64_t bytes = cArray_->n_buffers; for (int64_t i = 0; i < cArray_->n_children; ++i) { bytes += cArray_->children[i]->n_buffers; } return bytes; } std::shared_ptr<ArrowSchema> ArrowCStructColumnarBatch::exportArrowSchema() { return cSchema_; } std::shared_ptr<ArrowArray> ArrowCStructColumnarBatch::exportArrowArray() { return cArray_; } std::pair<char*, int> ArrowCStructColumnarBatch::getRowBytes(int32_t rowId) const { throw gluten::GlutenException("Not implemented getRowBytes for ArrowCStructColumnarBatch"); } std::shared_ptr<ColumnarBatch> CompositeColumnarBatch::create(std::vector<std::shared_ptr<ColumnarBatch>> batches) { int32_t numRows = -1; int32_t numColumns = 0; for (const auto& batch : batches) { if (numRows == -1) { numRows = batch->numRows(); } else if (batch->numRows() != numRows) { throw GlutenException("Mismatched row counts among the input batches during creating CompositeColumnarBatch"); } numColumns += batch->numColumns(); } return std::shared_ptr<ColumnarBatch>(new CompositeColumnarBatch(numColumns, numRows, std::move(batches))); } std::string CompositeColumnarBatch::getType() const { return "composite"; } int64_t CompositeColumnarBatch::numBytes() { if (compositeBatch_) { return compositeBatch_->numBytes(); } else { int64_t numBytes = 0L; for (const auto& batch : batches_) { numBytes += batch->numBytes(); } return numBytes; } } std::shared_ptr<ArrowArray> CompositeColumnarBatch::exportArrowArray() { ensureUnderlyingBatchCreated(); return compositeBatch_->exportArrowArray(); } std::shared_ptr<ArrowSchema> CompositeColumnarBatch::exportArrowSchema() { ensureUnderlyingBatchCreated(); return compositeBatch_->exportArrowSchema(); } const std::vector<std::shared_ptr<ColumnarBatch>>& CompositeColumnarBatch::getBatches() const { return batches_; } std::pair<char*, int> CompositeColumnarBatch::getRowBytes(int32_t rowId) const { throw gluten::GlutenException("Not implemented getRowBytes for CompositeColumnarBatch"); } CompositeColumnarBatch::CompositeColumnarBatch( int32_t numColumns, int32_t numRows, std::vector<std::shared_ptr<ColumnarBatch>> batches) : ColumnarBatch(numColumns, numRows) { this->batches_ = std::move(batches); } void CompositeColumnarBatch::ensureUnderlyingBatchCreated() { if (compositeBatch_ != nullptr) { return; } std::vector<std::shared_ptr<arrow::RecordBatch>> arrowBatches; for (const auto& batch : batches_) { auto cSchema = batch->exportArrowSchema(); auto cArray = batch->exportArrowArray(); auto arrowBatch = gluten::arrowGetOrThrow(arrow::ImportRecordBatch(cArray.get(), cSchema.get())); arrowBatches.push_back(std::move(arrowBatch)); } std::vector<std::shared_ptr<arrow::Field>> fields; std::vector<std::shared_ptr<arrow::ArrayData>> arrays; for (const auto& batch : arrowBatches) { if (batch->schema()->metadata() != nullptr) { throw gluten::GlutenException("Schema metadata not allowed"); } for (const auto& field : batch->schema()->fields()) { fields.push_back(field); } for (const auto& col : batch->column_data()) { arrays.push_back(col); } } compositeBatch_ = std::make_shared<ArrowColumnarBatch>( arrow::RecordBatch::Make(std::make_shared<arrow::Schema>(fields), numRows(), arrays)); } } // namespace gluten