Sources/SparkConnect/ArrowTable.swift (158 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

import Foundation

// Pairs an `ArrowField` with the type-erased chunked array that stores the
// column's values.
/// @nodoc
public class ArrowColumn {
  /// The field describing this column.
  public let field: ArrowField

  /// Type-erased storage for the column's chunked array.
  fileprivate let dataHolder: ChunkedArrayHolder

  /// The Arrow type reported by the underlying chunked array holder.
  public var type: ArrowType { return self.dataHolder.type }

  /// The length reported by the underlying chunked array holder.
  public var length: UInt { return self.dataHolder.length }

  /// The null count reported by the underlying chunked array holder.
  public var nullCount: UInt { return self.dataHolder.nullCount }

  /// The name of the column, taken from its field.
  public var name: String { return field.name }

  /// Returns the column's storage as a `ChunkedArray<T>`.
  ///
  /// This is a forced cast: the call traps when `T` does not match the element
  /// type of the stored chunked array.
  public func data<T>() -> ChunkedArray<T> {
    return (self.dataHolder.holder as! ChunkedArray<T>)  // swiftlint:disable:this force_cast
  }

  /// Creates a column from a field and its type-erased chunked array.
  public init(_ field: ArrowField, chunked: ChunkedArrayHolder) {
    self.field = field
    self.dataHolder = chunked
  }
}

// A schema together with one `ArrowColumn` per schema field.
/// @nodoc
public class ArrowTable {
  /// The schema of the table.
  public let schema: ArrowSchema

  /// The number of columns in the table.
  public var columnCount: UInt { return UInt(self.columns.count) }

  /// The length of the first column, or `0` when the table has no columns.
  public let rowCount: UInt

  /// The columns of the table.
  public let columns: [ArrowColumn]

  /// Creates a table from a schema and its columns.
  ///
  /// The row count is taken from the first column; a table without columns
  /// reports zero rows instead of trapping on an out-of-bounds access.
  init(_ schema: ArrowSchema, columns: [ArrowColumn]) {
    self.schema = schema
    self.columns = columns
    self.rowCount = columns.first?.length ?? 0
  }

  /// Builds a table by collecting, field by field, the arrays of every record batch.
  ///
  /// The schema of the first batch is used for the whole table.
  ///
  /// - Parameter recordBatches: The batches to combine.
  /// - Returns: `.failure(.arrayHasNoElements)` when `recordBatches` is empty,
  ///   `.failure(.runtimeError)` when a batch has fewer columns than the schema has
  ///   fields or when a column cannot be created, and the finished table otherwise.
  public static func from(recordBatches: [RecordBatch]) -> Result<ArrowTable, ArrowError> {
    guard let firstBatch = recordBatches.first else {
      return .failure(.arrayHasNoElements)
    }

    let schema = firstBatch.schema
    let fieldCount = schema.fields.count

    // One bucket of array holders per schema field, filled batch by batch.
    var holders = [[ArrowArrayHolder]](repeating: [], count: fieldCount)
    for recordBatch in recordBatches {
      // Report a malformed batch as an error instead of trapping on the index below.
      guard recordBatch.columns.count >= fieldCount else {
        return .failure(
          .runtimeError(
            "Record batch has \(recordBatch.columns.count) columns but the schema has \(fieldCount) fields"
          ))
      }
      for index in 0..<fieldCount {
        holders[index].append(recordBatch.columns[index])
      }
    }

    let builder = ArrowTable.Builder()
    for index in 0..<fieldCount {
      switch makeArrowColumn(schema.fields[index], holders: holders[index]) {
      case .success(let column):
        builder.addColumn(column)
      case .failure(let error):
        return .failure(error)
      }
    }
    return .success(builder.finish())
  }

  /// Creates one column from the array holders collected for `field`.
  ///
  /// The first holder's `getArrowColumn` performs the conversion; any error it
  /// throws is wrapped in `.runtimeError`. An empty `holders` list yields
  /// `.failure(.arrayHasNoElements)`.
  private static func makeArrowColumn(
    _ field: ArrowField,
    holders: [ArrowArrayHolder]
  ) -> Result<ArrowColumn, ArrowError> {
    guard let firstHolder = holders.first else {
      return .failure(.arrayHasNoElements)
    }
    do {
      return .success(try firstHolder.getArrowColumn(field, holders))
    } catch {
      return .failure(.runtimeError("\(error)"))
    }
  }

  /// Accumulates columns and their schema fields, then produces an `ArrowTable`.
  public class Builder {
    let schemaBuilder = ArrowSchema.Builder()
    var columns = [ArrowColumn]()

    public init() {}

    /// Adds a column made of a single array. The field is derived from the array:
    /// its type, and nullable only when the array contains nulls.
    @discardableResult
    public func addColumn<T>(_ fieldName: String, arrowArray: ArrowArray<T>) throws -> Builder {
      return self.addColumn(fieldName, chunked: try ChunkedArray([arrowArray]))
    }

    /// Adds a column backed by a chunked array. The field is derived from the
    /// chunked array: its type, and nullable only when it contains nulls.
    @discardableResult
    public func addColumn<T>(_ fieldName: String, chunked: ChunkedArray<T>) -> Builder {
      let field = ArrowField(fieldName, type: chunked.type, isNullable: chunked.nullCount != 0)
      self.schemaBuilder.addField(field)
      self.columns.append(ArrowColumn(field, chunked: ChunkedArrayHolder(chunked)))
      return self
    }

    /// Adds a column made of a single array, described by an explicit field.
    @discardableResult
    public func addColumn<T>(_ field: ArrowField, arrowArray: ArrowArray<T>) throws -> Builder {
      self.schemaBuilder.addField(field)
      let holder = ChunkedArrayHolder(try ChunkedArray([arrowArray]))
      self.columns.append(ArrowColumn(field, chunked: holder))
      return self
    }

    /// Adds a column backed by a chunked array, described by an explicit field.
    @discardableResult
    public func addColumn<T>(_ field: ArrowField, chunked: ChunkedArray<T>) -> Builder {
      self.schemaBuilder.addField(field)
      self.columns.append(ArrowColumn(field, chunked: ChunkedArrayHolder(chunked)))
      return self
    }

    /// Adds an already constructed column and registers its field in the schema.
    @discardableResult
    public func addColumn(_ column: ArrowColumn) -> Builder {
      self.schemaBuilder.addField(column.field)
      self.columns.append(column)
      return self
    }

    /// Builds the table from the fields and columns added so far.
    public func finish() -> ArrowTable {
      return ArrowTable(self.schemaBuilder.finish(), columns: self.columns)
    }
  }
}

/// A schema together with one array holder per column.
public class RecordBatch {
  /// The schema of the batch.
  public let schema: ArrowSchema

  /// The number of columns in the batch.
  public var columnCount: UInt { return UInt(self.columns.count) }

  /// The type-erased arrays of the batch, in schema order.
  public let columns: [ArrowArrayHolder]

  /// The length of the first column, or `0` when the batch has no columns.
  public let length: UInt

  /// Creates a record batch from a schema and its columns.
  ///
  /// The length is taken from the first column; a batch without columns reports
  /// a length of zero instead of trapping on an out-of-bounds access.
  public init(_ schema: ArrowSchema, columns: [ArrowArrayHolder]) {
    self.schema = schema
    self.columns = columns
    self.length = columns.first?.length ?? 0
  }

  /// Accumulates columns and their schema fields, then produces a `RecordBatch`.
  public class Builder {
    let schemaBuilder = ArrowSchema.Builder()
    var columns = [ArrowArrayHolder]()

    public init() {}

    /// Adds a column. The field is derived from the array: its type, and nullable
    /// only when the array contains nulls.
    @discardableResult
    public func addColumn(_ fieldName: String, arrowArray: ArrowArrayHolder) -> Builder {
      let field = ArrowField(
        fieldName, type: arrowArray.type, isNullable: arrowArray.nullCount != 0)
      self.schemaBuilder.addField(field)
      self.columns.append(arrowArray)
      return self
    }

    /// Adds a column described by an explicit field.
    @discardableResult
    public func addColumn(_ field: ArrowField, arrowArray: ArrowArrayHolder) -> Builder {
      self.schemaBuilder.addField(field)
      self.columns.append(arrowArray)
      return self
    }

    /// Builds the record batch.
    ///
    /// - Returns: `.failure(.runtimeError)` when the columns do not all have the
    ///   same length, and the finished batch otherwise. A builder without columns
    ///   yields an empty batch.
    public func finish() -> Result<RecordBatch, ArrowError> {
      if let columnLength = columns.first?.length {
        for column in columns where column.length != columnLength {
          return .failure(.runtimeError("Columns have different sizes"))
        }
      }
      return .success(RecordBatch(self.schemaBuilder.finish(), columns: self.columns))
    }
  }

  /// Returns the array of the column at `columnIndex` as an `ArrowArray<T>`.
  ///
  /// This is a forced cast: the call traps when `T` does not match the element
  /// type of the stored array, and when `columnIndex` is out of bounds.
  public func data<T>(for columnIndex: Int) -> ArrowArray<T> {
    let arrayHolder = column(columnIndex)
    return (arrayHolder.array as! ArrowArray<T>)  // swiftlint:disable:this force_cast
  }

  /// Returns the type-erased array of the column at `columnIndex`.
  public func anyData(for columnIndex: Int) -> AnyArray {
    let arrayHolder = column(columnIndex)
    return arrayHolder.array
  }

  /// Returns the array holder at `index`; traps when `index` is out of bounds.
  public func column(_ index: Int) -> ArrowArrayHolder {
    return self.columns[index]
  }

  /// Returns the array holder of the column called `name`, or `nil` when the
  /// schema has no field with that name.
  public func column(_ name: String) -> ArrowArrayHolder? {
    if let index = self.schema.fieldIndex(name) {
      return self.columns[index]
    }
    return nil
  }
}