Sources/SparkConnect/ArrowDecoder.swift (287 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. import Foundation /// @nodoc public class ArrowDecoder: Decoder { var rbIndex: UInt = 0 var singleRBCol: Int = 0 public var codingPath: [CodingKey] = [] public var userInfo: [CodingUserInfoKey: Any] = [:] public let rb: RecordBatch public let nameToCol: [String: ArrowArrayHolder] public let columns: [ArrowArrayHolder] public init(_ decoder: ArrowDecoder) { self.userInfo = decoder.userInfo self.codingPath = decoder.codingPath self.rb = decoder.rb self.columns = decoder.columns self.nameToCol = decoder.nameToCol self.rbIndex = decoder.rbIndex } public init(_ rb: RecordBatch) { self.rb = rb var colMapping = [String: ArrowArrayHolder]() var columns = [ArrowArrayHolder]() for index in 0..<self.rb.schema.fields.count { let field = self.rb.schema.fields[index] columns.append(self.rb.column(index)) colMapping[field.name] = self.rb.column(index) } self.columns = columns self.nameToCol = colMapping } public func decode<T: Decodable, U: Decodable>(_ type: [T: U].Type) throws -> [T: U] { var output = [T: U]() if rb.columnCount != 2 { throw ArrowError.invalid("RecordBatch column count of 2 is required to decode to map") } for index in 0..<rb.length { self.rbIndex = index self.singleRBCol = 0 let key = try T.init(from: self) self.singleRBCol = 1 let value = try U.init(from: self) output[key] = value } self.singleRBCol = 0 return output } public func decode<T: Decodable>(_ type: T.Type) throws -> [T] { var output = [T]() for index in 0..<rb.length { self.rbIndex = index output.append(try type.init(from: self)) } return output } public func container<Key>( keyedBy type: Key.Type ) -> KeyedDecodingContainer<Key> where Key: CodingKey { let container = ArrowKeyedDecoding<Key>(self, codingPath: codingPath) return KeyedDecodingContainer(container) } public func unkeyedContainer() -> UnkeyedDecodingContainer { return ArrowUnkeyedDecoding(self, codingPath: codingPath) } public func singleValueContainer() -> SingleValueDecodingContainer { return ArrowSingleValueDecoding(self, codingPath: codingPath) } func getCol(_ name: String) throws -> AnyArray { guard let col = self.nameToCol[name] else { throw ArrowError.invalid("Column for key \"\(name)\" not found") } return col.array } func getCol(_ index: Int) throws -> AnyArray { if index >= self.columns.count { throw ArrowError.outOfBounds(index: Int64(index)) } return self.columns[index].array } func doDecode<T>(_ key: CodingKey) throws -> T? { let array: AnyArray = try self.getCol(key.stringValue) return array.asAny(self.rbIndex) as? T } func doDecode<T>(_ col: Int) throws -> T? { let array: AnyArray = try self.getCol(col) return array.asAny(self.rbIndex) as? T } func isNull(_ key: CodingKey) throws -> Bool { let array: AnyArray = try self.getCol(key.stringValue) return array.asAny(self.rbIndex) == nil } func isNull(_ col: Int) throws -> Bool { let array: AnyArray = try self.getCol(col) return array.asAny(self.rbIndex) == nil } } private struct ArrowUnkeyedDecoding: UnkeyedDecodingContainer { var codingPath: [CodingKey] var count: Int? = 0 var isAtEnd: Bool = false var currentIndex: Int = 0 let decoder: ArrowDecoder init(_ decoder: ArrowDecoder, codingPath: [CodingKey]) { self.decoder = decoder self.codingPath = codingPath self.count = self.decoder.columns.count } mutating func increment() { self.currentIndex += 1 self.isAtEnd = self.currentIndex >= self.count! } mutating func decodeNil() throws -> Bool { defer { increment() } return try self.decoder.isNull(self.currentIndex) } mutating func decode<T>(_ type: T.Type) throws -> T where T: Decodable { if type == Int8?.self || type == Int16?.self || type == Int32?.self || type == Int64?.self || type == UInt8?.self || type == UInt16?.self || type == UInt32?.self || type == UInt64?.self || type == String?.self || type == Double?.self || type == Float?.self || type == Date?.self || type == Bool?.self || type == Bool.self || type == Int8.self || type == Int16.self || type == Int32.self || type == Int64.self || type == UInt8.self || type == UInt16.self || type == UInt32.self || type == UInt64.self || type == String.self || type == Double.self || type == Float.self || type == Date.self { defer { increment() } return try self.decoder.doDecode(self.currentIndex)! } else { throw ArrowError.invalid("Type \(type) is currently not supported") } } func nestedContainer<NestedKey>( keyedBy type: NestedKey.Type ) throws -> KeyedDecodingContainer<NestedKey> where NestedKey: CodingKey { throw ArrowError.invalid("Nested decoding is currently not supported.") } func nestedUnkeyedContainer() throws -> UnkeyedDecodingContainer { throw ArrowError.invalid("Nested decoding is currently not supported.") } func superDecoder() throws -> Decoder { throw ArrowError.invalid("super decoding is currently not supported.") } } private struct ArrowKeyedDecoding<Key: CodingKey>: KeyedDecodingContainerProtocol { var codingPath = [CodingKey]() var allKeys = [Key]() let decoder: ArrowDecoder init(_ decoder: ArrowDecoder, codingPath: [CodingKey]) { self.decoder = decoder self.codingPath = codingPath } func contains(_ key: Key) -> Bool { return self.decoder.nameToCol.keys.contains(key.stringValue) } func decodeNil(forKey key: Key) throws -> Bool { try self.decoder.isNull(key) } func decode(_ type: Bool.Type, forKey key: Key) throws -> Bool { return try self.decoder.doDecode(key)! } func decode(_ type: String.Type, forKey key: Key) throws -> String { return try self.decoder.doDecode(key)! } func decode(_ type: Double.Type, forKey key: Key) throws -> Double { return try self.decoder.doDecode(key)! } func decode(_ type: Float.Type, forKey key: Key) throws -> Float { return try self.decoder.doDecode(key)! } func decode(_ type: Int.Type, forKey key: Key) throws -> Int { throw ArrowError.invalid( "Int type is not supported (please use Int8, Int16, Int32 or Int64)") } func decode(_ type: Int8.Type, forKey key: Key) throws -> Int8 { return try self.decoder.doDecode(key)! } func decode(_ type: Int16.Type, forKey key: Key) throws -> Int16 { return try self.decoder.doDecode(key)! } func decode(_ type: Int32.Type, forKey key: Key) throws -> Int32 { return try self.decoder.doDecode(key)! } func decode(_ type: Int64.Type, forKey key: Key) throws -> Int64 { return try self.decoder.doDecode(key)! } func decode(_ type: UInt.Type, forKey key: Key) throws -> UInt { throw ArrowError.invalid( "UInt type is not supported (please use UInt8, UInt16, UInt32 or UInt64)") } func decode(_ type: UInt8.Type, forKey key: Key) throws -> UInt8 { return try self.decoder.doDecode(key)! } func decode(_ type: UInt16.Type, forKey key: Key) throws -> UInt16 { return try self.decoder.doDecode(key)! } func decode(_ type: UInt32.Type, forKey key: Key) throws -> UInt32 { return try self.decoder.doDecode(key)! } func decode(_ type: UInt64.Type, forKey key: Key) throws -> UInt64 { return try self.decoder.doDecode(key)! } func decode<T>(_ type: T.Type, forKey key: Key) throws -> T where T: Decodable { if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self { return try self.decoder.doDecode(key)! } else { throw ArrowError.invalid("Type \(type) is currently not supported") } } func nestedContainer<NestedKey>( keyedBy type: NestedKey.Type, forKey key: Key ) throws -> KeyedDecodingContainer<NestedKey> where NestedKey: CodingKey { throw ArrowError.invalid("Nested decoding is currently not supported.") } func nestedUnkeyedContainer(forKey key: Key) throws -> UnkeyedDecodingContainer { throw ArrowError.invalid("Nested decoding is currently not supported.") } func superDecoder() throws -> Decoder { throw ArrowError.invalid("super decoding is currently not supported.") } func superDecoder(forKey key: Key) throws -> Decoder { throw ArrowError.invalid("super decoding is currently not supported.") } } private struct ArrowSingleValueDecoding: SingleValueDecodingContainer { var codingPath = [CodingKey]() let decoder: ArrowDecoder init(_ decoder: ArrowDecoder, codingPath: [CodingKey]) { self.decoder = decoder self.codingPath = codingPath } func decodeNil() -> Bool { do { return try self.decoder.isNull(self.decoder.singleRBCol) } catch { return false } } func decode(_ type: Bool.Type) throws -> Bool { return try self.decoder.doDecode(self.decoder.singleRBCol)! } func decode(_ type: String.Type) throws -> String { return try self.decoder.doDecode(self.decoder.singleRBCol)! } func decode(_ type: Double.Type) throws -> Double { return try self.decoder.doDecode(self.decoder.singleRBCol)! } func decode(_ type: Float.Type) throws -> Float { return try self.decoder.doDecode(self.decoder.singleRBCol)! } func decode(_ type: Int.Type) throws -> Int { throw ArrowError.invalid( "Int type is not supported (please use Int8, Int16, Int32 or Int64)") } func decode(_ type: Int8.Type) throws -> Int8 { return try self.decoder.doDecode(self.decoder.singleRBCol)! } func decode(_ type: Int16.Type) throws -> Int16 { return try self.decoder.doDecode(self.decoder.singleRBCol)! } func decode(_ type: Int32.Type) throws -> Int32 { return try self.decoder.doDecode(self.decoder.singleRBCol)! } func decode(_ type: Int64.Type) throws -> Int64 { return try self.decoder.doDecode(self.decoder.singleRBCol)! } func decode(_ type: UInt.Type) throws -> UInt { throw ArrowError.invalid( "UInt type is not supported (please use UInt8, UInt16, UInt32 or UInt64)") } func decode(_ type: UInt8.Type) throws -> UInt8 { return try self.decoder.doDecode(self.decoder.singleRBCol)! } func decode(_ type: UInt16.Type) throws -> UInt16 { return try self.decoder.doDecode(self.decoder.singleRBCol)! } func decode(_ type: UInt32.Type) throws -> UInt32 { return try self.decoder.doDecode(self.decoder.singleRBCol)! } func decode(_ type: UInt64.Type) throws -> UInt64 { return try self.decoder.doDecode(self.decoder.singleRBCol)! } func decode<T>(_ type: T.Type) throws -> T where T: Decodable { if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self { return try self.decoder.doDecode(self.decoder.singleRBCol)! } else { throw ArrowError.invalid("Type \(type) is currently not supported") } } }