// Sources/SparkConnect/ArrowReader.swift
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
import FlatBuffers
import Foundation
/// Magic bytes that open and close an Arrow IPC *file* ("ARROW1").
let FILEMARKER = "ARROW1"
/// Continuation marker (0xFFFFFFFF as a signed 32-bit value) that may precede
/// an IPC message's length prefix; compared against the Int32 read from the stream.
let CONTINUATIONMARKER = -1
/// @nodoc
public class ArrowReader { // swiftlint:disable:this type_body_length
    /// Cursor state for decoding a single flatbuffer RecordBatch message.
    /// Decoding walks three parallel sequences — schema fields, field nodes and
    /// buffers — so this type keeps one monotonically increasing index for each.
    private class RecordBatchData {
        let schema: org_apache_arrow_flatbuf_Schema
        let recordBatch: org_apache_arrow_flatbuf_RecordBatch
        // Read positions; each is advanced only by its matching next*() method.
        private var fieldIndex: Int32 = 0
        private var nodeIndex: Int32 = 0
        private var bufferIndex: Int32 = 0
        init(
            _ recordBatch: org_apache_arrow_flatbuf_RecordBatch,
            schema: org_apache_arrow_flatbuf_Schema
        ) {
            self.recordBatch = recordBatch
            self.schema = schema
        }
        /// Returns the next field node of the batch, or `nil` when all nodes
        /// have been consumed. Advances the node cursor on success.
        func nextNode() -> org_apache_arrow_flatbuf_FieldNode? {
            if nodeIndex >= self.recordBatch.nodesCount { return nil }
            defer { nodeIndex += 1 }
            return self.recordBatch.nodes(at: nodeIndex)
        }
        /// Returns the next buffer descriptor of the batch, or `nil` when all
        /// buffers have been consumed. Advances the buffer cursor on success.
        func nextBuffer() -> org_apache_arrow_flatbuf_Buffer? {
            if bufferIndex >= self.recordBatch.buffersCount { return nil }
            defer { bufferIndex += 1 }
            return self.recordBatch.buffers(at: bufferIndex)
        }
        /// Returns the next top-level schema field, or `nil` when all fields
        /// have been consumed. Advances the field cursor on success.
        func nextField() -> org_apache_arrow_flatbuf_Field? {
            if fieldIndex >= self.schema.fieldsCount { return nil }
            defer { fieldIndex += 1 }
            return self.schema.fields(at: fieldIndex)
        }
        /// A batch is fully decoded once every field node has been visited.
        /// Node count (not field count) is the end condition because nested
        /// fields consume one extra node per child (see loadStructData).
        func isDone() -> Bool {
            return nodeIndex >= self.recordBatch.nodesCount
        }
    }
    /// Inputs needed to materialize the arrays of one record batch.
    private struct DataLoadInfo {
        // Raw bytes containing the batch's buffer data.
        let fileData: Data
        // File offset where the message body begins (the end offset of the
        // message header for file reads; 0 in fromMessage where the body is a
        // separate Data). Presumably added to each flatbuffer Buffer's relative
        // offset inside makeBuffer — TODO confirm, makeBuffer is not visible here.
        let messageOffset: Int64
        var batchData: RecordBatchData
    }
    /// Accumulates the decoded schema and record batches. Also retains the raw
    /// flatbuffer schema so later record-batch messages (fromMessage) can be
    /// decoded against it.
    public class ArrowReaderResult {
        fileprivate var messageSchema: org_apache_arrow_flatbuf_Schema?
        public var schema: ArrowSchema?
        public var batches = [RecordBatch]()
    }
    public init() {}
    /// Converts a flatbuffer schema into an `ArrowSchema`.
    /// - Returns: the built schema, or `.unknownType` if any field's type is
    ///   not recognized by `findArrowType`.
    private func loadSchema(_ schema: org_apache_arrow_flatbuf_Schema) -> Result<
        ArrowSchema, ArrowError
    > {
        let builder = ArrowSchema.Builder()
        for index in 0..<schema.fieldsCount {
            let field = schema.fields(at: index)!
            let fieldType = findArrowType(field)
            if fieldType.info == ArrowType.ArrowUnknown {
                return .failure(.unknownType("Unsupported field type found: \(field.typeType)"))
            }
            let arrowField = ArrowField(field.name!, type: fieldType, isNullable: field.nullable)
            builder.addField(arrowField)
        }
        return .success(builder.finish())
    }
    /// Loads a nested (struct-like) column: consumes one field node and one
    /// validity buffer, then recursively loads each child field (each child in
    /// turn consumes its own nodes/buffers from the shared cursor).
    /// - Returns: an array holder for the nested column, or the first error.
    private func loadStructData(
        _ loadInfo: DataLoadInfo,
        field: org_apache_arrow_flatbuf_Field
    )
        -> Result<ArrowArrayHolder, ArrowError>
    {
        guard let node = loadInfo.batchData.nextNode() else {
            return .failure(.invalid("Node not found"))
        }
        guard let nullBuffer = loadInfo.batchData.nextBuffer() else {
            return .failure(.invalid("Null buffer not found"))
        }
        // Validity bitmap holds one bit per value, rounded up to whole bytes.
        let nullLength = UInt(ceil(Double(node.length) / 8))
        let arrowNullBuffer = makeBuffer(
            nullBuffer, fileData: loadInfo.fileData,
            length: nullLength, messageOffset: loadInfo.messageOffset)
        var children = [ArrowData]()
        for index in 0..<field.childrenCount {
            let childField = field.children(at: index)!
            switch loadField(loadInfo, field: childField) {
            case .success(let holder):
                children.append(holder.array.arrowData)
            case .failure(let error):
                return .failure(error)
            }
        }
        return makeArrayHolder(
            field, buffers: [arrowNullBuffer],
            nullCount: UInt(node.nullCount), children: children,
            rbLength: UInt(loadInfo.batchData.recordBatch.length))
    }
    /// Loads a fixed-width primitive column: consumes one field node plus two
    /// buffers (validity bitmap, values).
    /// - Returns: an array holder for the column, or an error if the batch
    ///   runs out of nodes/buffers.
    private func loadPrimitiveData(
        _ loadInfo: DataLoadInfo,
        field: org_apache_arrow_flatbuf_Field
    )
        -> Result<ArrowArrayHolder, ArrowError>
    {
        guard let node = loadInfo.batchData.nextNode() else {
            return .failure(.invalid("Node not found"))
        }
        guard let nullBuffer = loadInfo.batchData.nextBuffer() else {
            return .failure(.invalid("Null buffer not found"))
        }
        guard let valueBuffer = loadInfo.batchData.nextBuffer() else {
            return .failure(.invalid("Value buffer not found"))
        }
        // Validity bitmap holds one bit per value, rounded up to whole bytes.
        let nullLength = UInt(ceil(Double(node.length) / 8))
        let arrowNullBuffer = makeBuffer(
            nullBuffer, fileData: loadInfo.fileData,
            length: nullLength, messageOffset: loadInfo.messageOffset)
        // NOTE(review): `length` here is the element count, not a byte size —
        // presumably makeBuffer derives the byte length from the flatbuffer
        // Buffer metadata and/or the element type; confirm against makeBuffer.
        let arrowValueBuffer = makeBuffer(
            valueBuffer, fileData: loadInfo.fileData,
            length: UInt(node.length), messageOffset: loadInfo.messageOffset)
        return makeArrayHolder(
            field, buffers: [arrowNullBuffer, arrowValueBuffer],
            nullCount: UInt(node.nullCount), children: nil,
            rbLength: UInt(loadInfo.batchData.recordBatch.length))
    }
    /// Loads a variable-length column (e.g. binary/string-like): consumes one
    /// field node plus three buffers (validity bitmap, offsets, values).
    /// - Returns: an array holder for the column, or an error if the batch
    ///   runs out of nodes/buffers.
    private func loadVariableData(
        _ loadInfo: DataLoadInfo,
        field: org_apache_arrow_flatbuf_Field
    )
        -> Result<ArrowArrayHolder, ArrowError>
    {
        guard let node = loadInfo.batchData.nextNode() else {
            return .failure(.invalid("Node not found"))
        }
        guard let nullBuffer = loadInfo.batchData.nextBuffer() else {
            return .failure(.invalid("Null buffer not found"))
        }
        guard let offsetBuffer = loadInfo.batchData.nextBuffer() else {
            return .failure(.invalid("Offset buffer not found"))
        }
        guard let valueBuffer = loadInfo.batchData.nextBuffer() else {
            return .failure(.invalid("Value buffer not found"))
        }
        // Validity bitmap holds one bit per value, rounded up to whole bytes.
        let nullLength = UInt(ceil(Double(node.length) / 8))
        let arrowNullBuffer = makeBuffer(
            nullBuffer, fileData: loadInfo.fileData,
            length: nullLength, messageOffset: loadInfo.messageOffset)
        // NOTE(review): as in loadPrimitiveData, `length` is the element count
        // for both the offsets and values buffers — presumably makeBuffer uses
        // the flatbuffer Buffer metadata for the actual byte length; confirm.
        let arrowOffsetBuffer = makeBuffer(
            offsetBuffer, fileData: loadInfo.fileData,
            length: UInt(node.length), messageOffset: loadInfo.messageOffset)
        let arrowValueBuffer = makeBuffer(
            valueBuffer, fileData: loadInfo.fileData,
            length: UInt(node.length), messageOffset: loadInfo.messageOffset)
        return makeArrayHolder(
            field, buffers: [arrowNullBuffer, arrowOffsetBuffer, arrowValueBuffer],
            nullCount: UInt(node.nullCount), children: nil,
            rbLength: UInt(loadInfo.batchData.recordBatch.length))
    }
    /// Dispatches one field to the appropriate loader by type category:
    /// nested -> struct loader, fixed-width -> primitive loader, everything
    /// else -> variable-length loader.
    private func loadField(
        _ loadInfo: DataLoadInfo,
        field: org_apache_arrow_flatbuf_Field
    )
        -> Result<ArrowArrayHolder, ArrowError>
    {
        if isNestedType(field.typeType) {
            return loadStructData(loadInfo, field: field)
        } else if isFixedPrimitive(field.typeType) {
            return loadPrimitiveData(loadInfo, field: field)
        } else {
            return loadVariableData(loadInfo, field: field)
        }
    }
    /// Decodes one RecordBatch message into a `RecordBatch` of column arrays.
    /// Iterates top-level schema fields until every field node of the batch
    /// has been consumed (nested fields consume extra nodes).
    /// - Parameters:
    ///   - data: bytes containing the batch's buffer data.
    ///   - messageEndOffset: offset within `data` where the message body
    ///     begins (0 when `data` is the body itself).
    private func loadRecordBatch(
        _ recordBatch: org_apache_arrow_flatbuf_RecordBatch,
        schema: org_apache_arrow_flatbuf_Schema,
        arrowSchema: ArrowSchema,
        data: Data,
        messageEndOffset: Int64
    ) -> Result<RecordBatch, ArrowError> {
        var columns: [ArrowArrayHolder] = []
        let batchData = RecordBatchData(recordBatch, schema: schema)
        let loadInfo = DataLoadInfo(
            fileData: data,
            messageOffset: messageEndOffset,
            batchData: batchData)
        while !batchData.isDone() {
            guard let field = batchData.nextField() else {
                // More nodes remain but the schema has no more fields:
                // the batch and schema disagree.
                return .failure(.invalid("Field not found"))
            }
            let result = loadField(loadInfo, field: field)
            switch result {
            case .success(let holder):
                columns.append(holder)
            case .failure(let error):
                return .failure(error)
            }
        }
        return .success(RecordBatch(arrowSchema, columns: columns))
    }
    /// Reads an Arrow IPC file image whose trailing "ARROW1" magic has already
    /// been stripped (see fromFile): the last 4 bytes are the footer length,
    /// preceded by the flatbuffer footer. Each record batch referenced by the
    /// footer is then decoded from its absolute offset, honoring the optional
    /// 0xFFFFFFFF continuation marker before the message length prefix.
    /// - Parameter useUnalignedBuffers: forwarded to FlatBuffers' ByteBuffer
    ///   to permit reads from unaligned memory.
    /// - Returns: schema plus all decoded batches, or the first error. Only
    ///   record-batch messages are handled; any other header type (e.g.
    ///   dictionary batches) fails with `.unknownError`.
    public func fromStream( // swiftlint:disable:this function_body_length
        _ fileData: Data,
        useUnalignedBuffers: Bool = false
    ) -> Result<ArrowReaderResult, ArrowError> {
        let footerLength = fileData.withUnsafeBytes { rawBuffer in
            rawBuffer.loadUnaligned(fromByteOffset: fileData.count - 4, as: Int32.self)
        }
        let result = ArrowReaderResult()
        let footerStartOffset = fileData.count - Int(footerLength + 4)
        let footerData = fileData[footerStartOffset...]
        let footerBuffer = ByteBuffer(
            data: footerData,
            allowReadingUnalignedBuffers: useUnalignedBuffers)
        let footer = org_apache_arrow_flatbuf_Footer.getRootAsFooter(bb: footerBuffer)
        let schemaResult = loadSchema(footer.schema!)
        switch schemaResult {
        case .success(let schema):
            result.schema = schema
        case .failure(let error):
            return .failure(error)
        }
        for index in 0..<footer.recordBatchesCount {
            let recordBatch = footer.recordBatches(at: index)!
            var messageLength = fileData.withUnsafeBytes { rawBuffer in
                rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: Int32.self)
            }
            // Number of Int32-sized prefixes before the flatbuffer message:
            // 1 for the length prefix, +1 if a continuation marker precedes it.
            var messageOffset: Int64 = 1
            if messageLength == CONTINUATIONMARKER {
                messageOffset += 1
                // The real length follows the continuation marker.
                messageLength = fileData.withUnsafeBytes { rawBuffer in
                    rawBuffer.loadUnaligned(
                        fromByteOffset: Int(recordBatch.offset + Int64(MemoryLayout<Int32>.size)),
                        as: Int32.self)
                }
            }
            let messageStartOffset =
                recordBatch.offset + (Int64(MemoryLayout<Int32>.size) * messageOffset)
            let messageEndOffset = messageStartOffset + Int64(messageLength)
            let recordBatchData = fileData[messageStartOffset..<messageEndOffset]
            let mbb = ByteBuffer(
                data: recordBatchData,
                allowReadingUnalignedBuffers: useUnalignedBuffers)
            let message = org_apache_arrow_flatbuf_Message.getRootAsMessage(bb: mbb)
            switch message.headerType {
            case .recordbatch:
                do {
                    let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)!
                    // The message body (buffer data) starts right after the
                    // header, hence messageEndOffset as the body base offset.
                    let recordBatch = try loadRecordBatch(
                        rbMessage,
                        schema: footer.schema!,
                        arrowSchema: result.schema!,
                        data: fileData,
                        messageEndOffset: messageEndOffset
                    ).get()
                    result.batches.append(recordBatch)
                } catch let error {
                    return .failure(error)
                }
            default:
                return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
            }
        }
        return .success(result)
    }
    /// Reads an Arrow IPC file from disk: validates the magic markers via
    /// `validateFileData`, strips the trailing "ARROW1" magic, and delegates
    /// to `fromStream`.
    /// - Returns: `.ioError` for invalid files, `.unknownError` for read
    ///   failures, otherwise the result of `fromStream`.
    public func fromFile(_ fileURL: URL) -> Result<ArrowReaderResult, ArrowError> {
        do {
            let fileData = try Data(contentsOf: fileURL)
            if !validateFileData(fileData) {
                return .failure(.ioError("Not a valid arrow file."))
            }
            let markerLength = FILEMARKER.utf8.count
            let footerLengthEnd = Int(fileData.count - markerLength)
            let data = fileData[..<(footerLengthEnd)]
            return fromStream(data)
        } catch {
            return .failure(.unknownError("Error loading file: \(error)"))
        }
    }
    /// Creates an empty result to accumulate messages fed to `fromMessage`.
    static public func makeArrowReaderResult() -> ArrowReaderResult {
        return ArrowReaderResult()
    }
    /// Decodes a single IPC message delivered as separate header and body
    /// buffers (the streaming shape used by e.g. Spark Connect responses).
    /// A schema message populates `result.schema`/`result.messageSchema`; a
    /// record-batch message is decoded against the previously received schema
    /// (so a schema message must arrive first, or the force-unwraps trap) and
    /// appended to `result.batches`. Any other header type is an error.
    /// - Parameter dataBody: the message body; buffer offsets are relative to
    ///   its start, hence `messageEndOffset: 0`.
    public func fromMessage(
        _ dataHeader: Data,
        dataBody: Data,
        result: ArrowReaderResult,
        useUnalignedBuffers: Bool = false
    ) -> Result<Void, ArrowError> {
        let mbb = ByteBuffer(
            data: dataHeader,
            allowReadingUnalignedBuffers: useUnalignedBuffers)
        let message = org_apache_arrow_flatbuf_Message.getRootAsMessage(bb: mbb)
        switch message.headerType {
        case .schema:
            let sMessage = message.header(type: org_apache_arrow_flatbuf_Schema.self)!
            switch loadSchema(sMessage) {
            case .success(let schema):
                result.schema = schema
                result.messageSchema = sMessage
                return .success(())
            case .failure(let error):
                return .failure(error)
            }
        case .recordbatch:
            let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)!
            do {
                let recordBatch = try loadRecordBatch(
                    rbMessage, schema: result.messageSchema!, arrowSchema: result.schema!,
                    data: dataBody, messageEndOffset: 0
                ).get()
                result.batches.append(recordBatch)
                return .success(())
            } catch let error {
                return .failure(error)
            }
        default:
            return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
        }
    }
}