Sources/SparkConnect/ArrowWriter.swift (304 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

import FlatBuffers
import Foundation

/// Abstraction over an append-only byte sink used by `ArrowWriter`.
/// `count` is the number of bytes written so far; the writer uses it to
/// compute record-batch block offsets and the footer length.
/// @nodoc
public protocol DataWriter {
  /// Total number of bytes appended to this sink so far.
  var count: Int { get }
  /// Appends `data` to the end of the sink.
  func append(_ data: Data)
}

/// Serializes an `ArrowSchema` plus `RecordBatch`es into Arrow IPC bytes:
/// either a stream/file layout (`toStream`/`toFile`) or individual
/// FlatBuffer messages (`toMessage`).
/// @nodoc
public class ArrowWriter {  // swiftlint:disable:this type_body_length
  /// `DataWriter` that accumulates everything in an in-memory `Data` buffer.
  public class InMemDataWriter: DataWriter {
    /// The bytes written so far; externally read-only.
    public private(set) var data: Data
    public var count: Int { return data.count }
    public init(_ data: Data) { self.data = data }
    /// Starts from an empty buffer.
    convenience init() { self.init(Data()) }
    public func append(_ data: Data) { self.data.append(data) }
  }

  /// `DataWriter` that streams bytes to a `FileHandle`, tracking the byte
  /// count itself (the handle's own offset is never queried).
  public class FileDataWriter: DataWriter {
    private var handle: FileHandle
    private var currentSize: Int = 0
    public var count: Int { return currentSize }
    public init(_ handle: FileHandle) { self.handle = handle }
    public func append(_ data: Data) {
      // NOTE(review): FileHandle.write(_:) raises an ObjC exception on I/O
      // failure rather than throwing; errors are not surfaced here.
      self.handle.write(data)
      self.currentSize += data.count
    }
  }

  /// Immutable bundle of everything needed to produce one IPC payload:
  /// the message header type, the schema, and zero or more record batches.
  public class Info {
    public let type: org_apache_arrow_flatbuf_MessageHeader
    public let schema: ArrowSchema
    public let batches: [RecordBatch]
    public init(
      _ type: org_apache_arrow_flatbuf_MessageHeader, schema: ArrowSchema,
      batches: [RecordBatch]
    ) {
      self.type = type
      self.schema = schema
      self.batches = batches
    }

    /// Convenience for schema-only payloads (no record batches).
    public convenience init(_ type: org_apache_arrow_flatbuf_MessageHeader, schema: ArrowSchema) {
      self.init(type, schema: schema, batches: [RecordBatch]())
    }
  }

  public init() {}

  /// Serializes a single `ArrowField` into `fbb` as a FlatBuffer `Field`
  /// table (name, nullable flag, type enum, and type table).
  /// - Returns: the offset of the finished `Field` table, or the first
  ///   `ArrowError` produced by type conversion.
  private func writeField(_ fbb: inout FlatBufferBuilder, field: ArrowField) -> Result<
    Offset, ArrowError
  > {
    // Strings and child tables must be created before startField; only the
    // type-enum lookup may run between start and end.
    let nameOffset = fbb.create(string: field.name)
    let fieldTypeOffsetResult = toFBType(&fbb, arrowType: field.type)
    let startOffset = org_apache_arrow_flatbuf_Field.startField(&fbb)
    org_apache_arrow_flatbuf_Field.add(name: nameOffset, &fbb)
    org_apache_arrow_flatbuf_Field.add(nullable: field.isNullable, &fbb)
    switch toFBTypeEnum(field.type) {
    case .success(let type):
      org_apache_arrow_flatbuf_Field.add(typeType: type, &fbb)
    case .failure(let error):
      // NOTE(review): the Field table is left unfinished on this path;
      // callers abandon the builder, so this appears safe — confirm.
      return .failure(error)
    }
    switch fieldTypeOffsetResult {
    case .success(let offset):
      org_apache_arrow_flatbuf_Field.add(type: offset, &fbb)
      return .success(org_apache_arrow_flatbuf_Field.endField(&fbb, start: startOffset))
    case .failure(let error):
      return .failure(error)
    }
  }

  /// Serializes the schema's fields into `fbb` and creates the FlatBuffer
  /// `Schema` table (little-endian).
  /// - Returns: the `Schema` table offset, or the first field error.
  private func writeSchema(_ fbb: inout FlatBufferBuilder, schema: ArrowSchema) -> Result<
    Offset, ArrowError
  > {
    var fieldOffsets = [Offset]()
    for field in schema.fields {
      switch writeField(&fbb, field: field) {
      case .success(let offset):
        fieldOffsets.append(offset)
      case .failure(let error):
        return .failure(error)
      }
    }
    let fieldsOffset: Offset = fbb.createVector(ofOffsets: fieldOffsets)
    let schemaOffset = org_apache_arrow_flatbuf_Schema.createSchema(
      &fbb, endianness: .little, fieldsVectorOffset: fieldsOffset)
    return .success(schemaOffset)
  }

  /// Writes each batch to `writer` as: 4-byte little-endian metadata length,
  /// the FlatBuffer message bytes, then the padded buffer bodies.
  /// - Returns: one footer `Block` per batch (recording where each batch
  ///   started in `writer`), or the first error.
  private func writeRecordBatches(
    _ writer: inout DataWriter,
    batches: [RecordBatch]
  ) -> Result<[org_apache_arrow_flatbuf_Block], ArrowError> {
    var rbBlocks = [org_apache_arrow_flatbuf_Block]()
    for batch in batches {
      // Block offset = writer position before this batch's length prefix.
      let startIndex = writer.count
      switch writeRecordBatch(batch: batch) {
      case .success(let rbResult):
        // rbResult.0 = serialized message; rbResult.1.o = its byte length
        // (see writeRecordBatch's return). Emit the length prefix first.
        withUnsafeBytes(of: rbResult.1.o.littleEndian) { writer.append(Data($0)) }
        writer.append(rbResult.0)
        switch writeRecordBatchData(&writer, batch: batch) {
        case .success:
          // NOTE(review): per the Arrow file format, Block.metaDataLength
          // should be the metadata size and bodyLength the buffer-body
          // size; here metaDataLength is 0 and bodyLength holds the
          // metadata length (rbResult.1.o). A matching reader may rely on
          // this — confirm against the reader before changing.
          rbBlocks.append(
            org_apache_arrow_flatbuf_Block(
              offset: Int64(startIndex), metaDataLength: Int32(0),
              bodyLength: Int64(rbResult.1.o)))
        case .failure(let error):
          return .failure(error)
        }
      case .failure(let error):
        return .failure(error)
      }
    }
    return .success(rbBlocks)
  }

  /// Builds the FlatBuffer `Message` (header type `.recordbatch`) for one
  /// batch: field nodes, buffer descriptors with aligned offsets/lengths,
  /// and the total body length.
  /// - Returns: (serialized message bytes, Offset whose `.o` equals the
  ///   serialized message's byte count) — the second element is a length,
  ///   not a table offset.
  private func writeRecordBatch(batch: RecordBatch) -> Result<(Data, Offset), ArrowError> {
    let schema = batch.schema
    var fbb = FlatBufferBuilder()
    // Write out field nodes. FlatBuffer struct vectors are built back to
    // front, hence the reversed index order.
    var fieldNodeOffsets = [Offset]()
    fbb.startVector(
      schema.fields.count, elementSize: MemoryLayout<org_apache_arrow_flatbuf_FieldNode>.size)
    for index in (0..<schema.fields.count).reversed() {
      let column = batch.column(index)
      let fieldNode = org_apache_arrow_flatbuf_FieldNode(
        length: Int64(column.length), nullCount: Int64(column.nullCount))
      fieldNodeOffsets.append(fbb.create(struct: fieldNode))
    }
    let nodeOffset = fbb.endVector(len: schema.fields.count)
    // Write out buffer descriptors: each buffer's offset is the running,
    // alignment-padded total of all preceding buffers in body order.
    var buffers = [org_apache_arrow_flatbuf_Buffer]()
    var bufferOffset = Int(0)
    for index in 0..<batch.schema.fields.count {
      let column = batch.column(index)
      let colBufferDataSizes = column.getBufferDataSizes()
      for var bufferDataSize in colBufferDataSizes {
        // Round the size up to the IPC alignment boundary so descriptor
        // offsets match the padded bodies written by writeRecordBatchData.
        bufferDataSize = getPadForAlignment(bufferDataSize)
        let buffer = org_apache_arrow_flatbuf_Buffer(
          offset: Int64(bufferOffset), length: Int64(bufferDataSize))
        buffers.append(buffer)
        bufferOffset += bufferDataSize
      }
    }
    // Struct vector again built in reverse.
    org_apache_arrow_flatbuf_RecordBatch.startVectorOfBuffers(batch.schema.fields.count, in: &fbb)
    for buffer in buffers.reversed() {
      fbb.create(struct: buffer)
    }
    let batchBuffersOffset = fbb.endVector(len: buffers.count)
    let startRb = org_apache_arrow_flatbuf_RecordBatch.startRecordBatch(&fbb)
    org_apache_arrow_flatbuf_RecordBatch.addVectorOf(nodes: nodeOffset, &fbb)
    org_apache_arrow_flatbuf_RecordBatch.addVectorOf(buffers: batchBuffersOffset, &fbb)
    org_apache_arrow_flatbuf_RecordBatch.add(length: Int64(batch.length), &fbb)
    let recordBatchOffset = org_apache_arrow_flatbuf_RecordBatch.endRecordBatch(
      &fbb, start: startRb)
    // Body size = padded total of all buffer bytes that will follow the
    // message (written separately by writeRecordBatchData).
    let bodySize = Int64(bufferOffset)
    let startMessage = org_apache_arrow_flatbuf_Message.startMessage(&fbb)
    org_apache_arrow_flatbuf_Message.add(version: .max, &fbb)
    org_apache_arrow_flatbuf_Message.add(bodyLength: Int64(bodySize), &fbb)
    org_apache_arrow_flatbuf_Message.add(headerType: .recordbatch, &fbb)
    org_apache_arrow_flatbuf_Message.add(header: recordBatchOffset, &fbb)
    let messageOffset = org_apache_arrow_flatbuf_Message.endMessage(&fbb, start: startMessage)
    fbb.finish(offset: messageOffset)
    // The returned Offset deliberately carries fbb.data.count (the message
    // length) so callers can emit the IPC length prefix.
    return .success((fbb.data, Offset(offset: UInt32(fbb.data.count))))
  }

  /// Appends every column buffer of `batch` to `writer`, padding each
  /// buffer to the alignment boundary. Order must match the descriptor
  /// order produced in writeRecordBatch.
  /// - Returns: `.success(true)` always (signature allows future errors).
  private func writeRecordBatchData(_ writer: inout DataWriter, batch: RecordBatch) -> Result<
    Bool, ArrowError
  > {
    for index in 0..<batch.schema.fields.count {
      let column = batch.column(index)
      let colBufferData = column.getBufferData()
      for var bufferData in colBufferData {
        addPadForAlignment(&bufferData)
        writer.append(bufferData)
      }
    }
    return .success(true)
  }

  /// Builds the Arrow file-format `Footer` FlatBuffer: the schema plus the
  /// vector of record-batch `Block`s (struct vector built in reverse).
  /// - Returns: the finished footer bytes, or a schema-serialization error.
  private func writeFooter(
    schema: ArrowSchema, rbBlocks: [org_apache_arrow_flatbuf_Block]
  ) -> Result<Data, ArrowError> {
    var fbb: FlatBufferBuilder = FlatBufferBuilder()
    switch writeSchema(&fbb, schema: schema) {
    case .success(let schemaOffset):
      fbb.startVector(
        rbBlocks.count, elementSize: MemoryLayout<org_apache_arrow_flatbuf_Block>.size)
      for blkInfo in rbBlocks.reversed() {
        fbb.create(struct: blkInfo)
      }
      let rbBlkEnd = fbb.endVector(len: rbBlocks.count)
      let footerStartOffset = org_apache_arrow_flatbuf_Footer.startFooter(&fbb)
      org_apache_arrow_flatbuf_Footer.add(schema: schemaOffset, &fbb)
      org_apache_arrow_flatbuf_Footer.addVectorOf(recordBatches: rbBlkEnd, &fbb)
      let footerOffset = org_apache_arrow_flatbuf_Footer.endFooter(&fbb, start: footerStartOffset)
      fbb.finish(offset: footerOffset)
    case .failure(let error):
      return .failure(error)
    }
    return .success(fbb.data)
  }

  /// Writes the full payload to `writer`: schema message, each record batch
  /// (length-prefixed message + body), then the footer, padding, a 4-byte
  /// zero sentinel, and finally the footer length (bytes from footer start
  /// to end of writer, little-endian).
  private func writeStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<
    Bool, ArrowError
  > {
    var fbb: FlatBufferBuilder = FlatBufferBuilder()
    switch writeSchema(&fbb, schema: info.schema) {
    case .success(let schemaOffset):
      fbb.finish(offset: schemaOffset)
      writer.append(fbb.data)
    case .failure(let error):
      return .failure(error)
    }
    switch writeRecordBatches(&writer, batches: info.batches) {
    case .success(let rbBlocks):
      switch writeFooter(schema: info.schema, rbBlocks: rbBlocks) {
      case .success(let footerData):
        // NOTE(review): this re-finishes the already-finished schema
        // builder with its current buffer size; the result is never read.
        // Looks like dead code — confirm before removing.
        fbb.finish(offset: Offset(offset: fbb.buffer.size))
        let footerOffset = writer.count
        writer.append(footerData)
        addPadForAlignment(&writer)
        // 4-byte zero sentinel, then the footer length measured from the
        // footer's first byte through the sentinel.
        withUnsafeBytes(of: Int32(0).littleEndian) { writer.append(Data($0)) }
        let footerDiff = (UInt32(writer.count) - UInt32(footerOffset))
        withUnsafeBytes(of: footerDiff.littleEndian) { writer.append(Data($0)) }
      case .failure(let error):
        return .failure(error)
      }
    case .failure(let error):
      return .failure(error)
    }
    return .success(true)
  }

  /// Serializes `info` into an in-memory `Data` using the same layout as
  /// `writeStream`.
  public func toStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
    var writer: any DataWriter = InMemDataWriter()
    switch writeStream(&writer, info: info) {
    case .success:
      if let memWriter = writer as? InMemDataWriter {
        return .success(memWriter.data)
      } else {
        return .failure(.invalid("Unable to cast writer"))
      }
    case .failure(let error):
      return .failure(error)
    }
  }

  /// Writes `info` to `fileName` framed by the Arrow file magic marker on
  /// both ends. The file is truncated/created first via an empty write.
  public func toFile(_ fileName: URL, info: ArrowWriter.Info) -> Result<Bool, ArrowError> {
    do {
      try Data().write(to: fileName)
    } catch {
      return .failure(.ioError("\(error)"))
    }
    // NOTE(review): force unwrap — crashes if the just-created file cannot
    // be opened for update (permissions, race). Consider surfacing ioError.
    let fileHandle = FileHandle(forUpdatingAtPath: fileName.path)!
    defer { fileHandle.closeFile() }
    // NOTE(review): markerData is padded here but never appended — the
    // unpadded FILEMARKER below is what gets written. Either this is dead
    // code or the padded marker was intended; confirm against the reader.
    var markerData = FILEMARKER.data(using: .utf8)!
    addPadForAlignment(&markerData)
    var writer: any DataWriter = FileDataWriter(fileHandle)
    writer.append(FILEMARKER.data(using: .utf8)!)
    switch writeStream(&writer, info: info) {
    case .success:
      // Trailing magic marker closes the file format.
      writer.append(FILEMARKER.data(using: .utf8)!)
    case .failure(let error):
      return .failure(error)
    }
    return .success(true)
  }

  /// Serializes a single batch into two parts: [0] the (padded) FlatBuffer
  /// message bytes, [1] the buffer-body bytes — e.g. for transports that
  /// send metadata and body separately.
  public func toMessage(_ batch: RecordBatch) -> Result<[Data], ArrowError> {
    var writer: any DataWriter = InMemDataWriter()
    switch writeRecordBatch(batch: batch) {
    case .success(let message):
      writer.append(message.0)
      addPadForAlignment(&writer)
      var dataWriter: any DataWriter = InMemDataWriter()
      switch writeRecordBatchData(&dataWriter, batch: batch) {
      case .success:
        // Casts are safe: both writers are created as InMemDataWriter above.
        return .success([
          (writer as! InMemDataWriter).data,  // swiftlint:disable:this force_cast
          (dataWriter as! InMemDataWriter).data,  // swiftlint:disable:this force_cast
        ])
      case .failure(let error):
        return .failure(error)
      }
    case .failure(let error):
      return .failure(error)
    }
  }

  /// Serializes a schema-only FlatBuffer `Message` (header type `.schema`,
  /// zero body length).
  public func toMessage(_ schema: ArrowSchema) -> Result<Data, ArrowError> {
    var schemaSize: Int32 = 0
    var fbb = FlatBufferBuilder()
    switch writeSchema(&fbb, schema: schema) {
    case .success(let schemaOffset):
      // NOTE(review): the schema table's offset value is reused below as
      // the Message.header offset via UOffset(schemaSize) — this relies on
      // builder internals; confirm it matches the reader's expectation.
      schemaSize = Int32(schemaOffset.o)
    case .failure(let error):
      return .failure(error)
    }
    let startMessage = org_apache_arrow_flatbuf_Message.startMessage(&fbb)
    org_apache_arrow_flatbuf_Message.add(bodyLength: Int64(0), &fbb)
    org_apache_arrow_flatbuf_Message.add(headerType: .schema, &fbb)
    org_apache_arrow_flatbuf_Message.add(header: Offset(offset: UOffset(schemaSize)), &fbb)
    org_apache_arrow_flatbuf_Message.add(version: .max, &fbb)
    let messageOffset = org_apache_arrow_flatbuf_Message.endMessage(&fbb, start: startMessage)
    fbb.finish(offset: messageOffset)
    return .success(fbb.data)
  }
}