Sources/SparkConnect/Extension.swift (279 lines of code) (raw):
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//
import Foundation
extension String {
  /// Convert a Swift value into a SparkConnect literal `Expression`.
  ///
  /// Supported types: `Bool`, `Int8`, `Int16`, `Int32`, `Int64`, `Int`,
  /// `Float`, `Double`, and `String`. Any other type is force-cast to
  /// `String` and will trap at runtime (preserving the prior behavior).
  private static func toLiteralExpression(_ value: Sendable) -> Spark_Connect_Expression {
    var literal = ExpressionLiteral()
    switch value {
    case let value as Bool:
      literal.boolean = value
    case let value as Int8:
      literal.byte = Int32(value)
    case let value as Int16:
      literal.short = Int32(value)
    case let value as Int32:
      literal.integer = value
    case let value as Int64:
      literal.long = value
    case let value as Int:
      literal.long = Int64(value)
    case let value as Float:
      literal.float = value
    case let value as Double:
      literal.double = value
    case let value as String:
      literal.string = value
    default:
      // Unsupported argument types trap here, as they did before this refactoring.
      literal.string = value as! String
    }
    var expr = Spark_Connect_Expression()
    expr.literal = literal
    return expr
  }

  /// Wrap a `Spark_Connect_SQL` message in a root `Plan`.
  private static func toRootPlan(_ sql: Spark_Connect_SQL) -> Plan {
    var relation = Relation()
    relation.sql = sql
    var plan = Plan()
    plan.opType = Plan.OneOf_OpType.root(relation)
    return plan
  }

  /// Get a `Plan` instance from a string.
  var toSparkConnectPlan: Plan {
    var sql = Spark_Connect_SQL()
    sql.query = self
    return String.toRootPlan(sql)
  }

  /// Get a `Plan` for this SQL text with positional parameters.
  /// - Parameter posArguments: Values bound, in order, to the query's `?` placeholders.
  func toSparkConnectPlan(_ posArguments: [Sendable]) -> Plan {
    var sql = Spark_Connect_SQL()
    sql.query = self
    sql.posArguments = posArguments.map(String.toLiteralExpression)
    return String.toRootPlan(sql)
  }

  /// Get a `Plan` for this SQL text with named parameters.
  /// - Parameter namedArguments: Values bound to the query's named placeholders.
  func toSparkConnectPlan(_ namedArguments: [String: Sendable]) -> Plan {
    var sql = Spark_Connect_SQL()
    sql.query = self
    sql.namedArguments = namedArguments.mapValues(String.toLiteralExpression)
    return String.toRootPlan(sql)
  }

  /// Get a `UserContext` instance from a string.
  /// Note: the string is used as both the user ID and the user name.
  var toUserContext: UserContext {
    var context = UserContext()
    context.userID = self
    context.userName = self
    return context
  }

  /// Get a `KeyValue` instance by using a string as the key (value left unset).
  var toKeyValue: KeyValue {
    var keyValue = KeyValue()
    keyValue.key = self
    return keyValue
  }

  /// Get an `UnresolvedAttribute` whose identifier is this string (parsed later by the server).
  var toUnresolvedAttribute: UnresolvedAttribute {
    var attribute = UnresolvedAttribute()
    attribute.unparsedIdentifier = self
    return attribute
  }

  /// Get an `ExpressionString` carrying this string as a raw SQL expression.
  var toExpressionString: ExpressionString {
    var expression = ExpressionString()
    expression.expression = self
    return expression
  }

  /// Map this string to an `ExplainMode`. Matching is case-sensitive;
  /// unknown values fall back to `.simple`.
  var toExplainMode: ExplainMode {
    let mode = switch self {
    case "codegen": ExplainMode.codegen
    case "cost": ExplainMode.cost
    case "extended": ExplainMode.extended
    case "formatted": ExplainMode.formatted
    case "simple": ExplainMode.simple
    default: ExplainMode.simple
    }
    return mode
  }

  /// Map this string (case-insensitively) to a `SaveMode`;
  /// unknown values fall back to `.errorIfExists`.
  var toSaveMode: SaveMode {
    return switch self.lowercased() {
    case "append": SaveMode.append
    case "overwrite": SaveMode.overwrite
    // The input is lowercased above, so "errorIfExists" must be matched
    // as "errorifexists" (the mixed-case pattern was dead code).
    case "error", "errorifexists": SaveMode.errorIfExists
    case "ignore": SaveMode.ignore
    default: SaveMode.errorIfExists
    }
  }

  /// Map this string (case-insensitively) to a `JoinType`;
  /// unknown values fall back to `.inner`.
  var toJoinType: JoinType {
    return switch self.lowercased() {
    case "inner": JoinType.inner
    case "cross": JoinType.cross
    case "outer", "full", "fullouter", "full_outer": JoinType.fullOuter
    case "left", "leftouter", "left_outer": JoinType.leftOuter
    case "right", "rightouter", "right_outer": JoinType.rightOuter
    case "semi", "leftsemi", "left_semi": JoinType.leftSemi
    case "anti", "leftanti", "left_anti": JoinType.leftAnti
    default: JoinType.inner
    }
  }

  /// Map this string (case-insensitively) to a `GroupType`;
  /// unknown values map to `.UNRECOGNIZED(-1)`.
  var toGroupType: GroupType {
    return switch self.lowercased() {
    case "groupby": .groupby
    case "rollup": .rollup
    case "cube": .cube
    case "pivot": .pivot
    case "groupingsets": .groupingSets
    default: .UNRECOGNIZED(-1)
    }
  }
}
extension [String: String] {
  /// Get an array of `KeyValue` from `[String: String]`, one entry per pair.
  /// Ordering follows the dictionary's (unspecified) iteration order.
  var toSparkConnectKeyValue: [KeyValue] {
    return self.map { entry in
      var pair = KeyValue()
      pair.key = entry.key
      pair.value = entry.value
      return pair
    }
  }
}
extension Data {
  /// Read an `Int32` (host byte order) from the first 4 bytes.
  ///
  /// Uses `loadUnaligned` because `Data` (especially slices of a larger
  /// buffer) is not guaranteed to be 4-byte aligned, and `load(as:)`
  /// traps on misaligned addresses.
  var int32: Int32 { withUnsafeBytes({ $0.loadUnaligned(as: Int32.self) }) }
}
extension SparkSession: Equatable {
  /// Two sessions are considered equal when they carry the same session ID.
  public static func == (lhs: SparkSession, rhs: SparkSession) -> Bool {
    lhs.sessionID == rhs.sessionID
  }
}
extension YearMonthInterval {
  /// Map a year-month interval field code to its SQL unit name.
  /// - Throws: `SparkConnectError.InvalidTypeException` for unknown codes.
  func fieldToString(_ field: Int32) throws -> String {
    return switch field {
    case 0: "year"
    case 1: "month"
    default:
      throw SparkConnectError.InvalidTypeException
    }
  }
  /// Render as Spark's simple string, e.g. `interval year to month`.
  /// - Throws: `SparkConnectError.InvalidTypeException` when either field is
  ///   unknown or the range is inverted (start after end).
  func toString() throws -> String {
    let startFieldName = try fieldToString(self.startField)
    let endFieldName = try fieldToString(self.endField)
    // Compare the numeric field codes, not the unit names: lexicographically
    // "year" < "month" is false, which made `interval year to month`
    // (startField=0, endField=1) incorrectly throw.
    let interval = if self.startField == self.endField {
      "interval \(startFieldName)"
    } else if self.startField < self.endField {
      "interval \(startFieldName) to \(endFieldName)"
    } else {
      throw SparkConnectError.InvalidTypeException
    }
    return interval
  }
}
extension DayTimeInterval {
  /// Map a day-time interval field code to its SQL unit name.
  /// - Throws: `SparkConnectError.InvalidTypeException` for unknown codes.
  func fieldToString(_ field: Int32) throws -> String {
    return switch field {
    case 0: "day"
    case 1: "hour"
    case 2: "minute"
    case 3: "second"
    default:
      throw SparkConnectError.InvalidTypeException
    }
  }
  /// Render as Spark's simple string, e.g. `interval day to second`.
  /// - Throws: `SparkConnectError.InvalidTypeException` when either field is
  ///   unknown or the range is inverted (start after end).
  func toString() throws -> String {
    let startFieldName = try fieldToString(self.startField)
    let endFieldName = try fieldToString(self.endField)
    // Compare the numeric field codes rather than the unit names; the names
    // day/hour/minute/second only sort correctly by alphabetical accident.
    let interval = if self.startField == self.endField {
      "interval \(startFieldName)"
    } else if self.startField < self.endField {
      "interval \(startFieldName) to \(endFieldName)"
    } else {
      throw SparkConnectError.InvalidTypeException
    }
    return interval
  }
}
extension MapType {
  /// Render this map type in Spark's simple-string form, e.g. `map<string,int>`.
  /// - Throws: rethrows any error from rendering the key or value type.
  func toString() throws -> String {
    let renderedKey = try self.keyType.simpleString
    let renderedValue = try self.valueType.simpleString
    return "map<\(renderedKey),\(renderedValue)>"
  }
}
extension StructType {
  /// Render this struct type in Spark's simple-string form,
  /// e.g. `struct<a:int,b:string>`.
  /// - Throws: rethrows any error from rendering a field's data type.
  func toString() throws -> String {
    var renderedFields: [String] = []
    for field in fields {
      renderedFields.append("\(field.name):\(try field.dataType.simpleString)")
    }
    return "struct<\(renderedFields.joined(separator: ","))>"
  }
}
extension DataType {
  /// Spark's "simple string" rendering of this data type (e.g. `int`,
  /// `decimal(10,2)`, `array<string>`), mirroring Spark SQL's
  /// `DataType.simpleString`.
  /// - Throws: `SparkConnectError.InvalidTypeException` for unrecognized kinds.
  var simpleString: String {
    get throws {
      return switch self.kind {
      case .null:
        "void"
      case .binary:
        "binary"
      case .boolean:
        "boolean"
      case .byte:
        "tinyint"
      case .short:
        "smallint"
      case .integer:
        "int"
      case .long:
        "bigint"
      case .float:
        "float"
      case .double:
        "double"
      case .decimal:
        "decimal(\(self.decimal.precision),\(self.decimal.scale))"
      case .string:
        "string"
      case .char:
        // Spark renders char/varchar with their length, e.g. `char(10)`.
        "char(\(self.char.length))"
      case .varChar:
        "varchar(\(self.varChar.length))"
      case .date:
        "date"
      case .timestamp:
        "timestamp"
      case .timestampNtz:
        "timestamp_ntz"
      case .calendarInterval:
        "interval"
      case .yearMonthInterval:
        try self.yearMonthInterval.toString()
      case .dayTimeInterval:
        try self.dayTimeInterval.toString()
      case .array:
        "array<\(try self.array.elementType.simpleString)>"
      case .struct:
        try self.struct.toString()
      case .map:
        try self.map.toString()
      case .variant:
        "variant"
      case .udt:
        self.udt.type
      case .unparsed:
        // Already a type string supplied by the caller; pass it through.
        self.unparsed.dataTypeString
      default:
        throw SparkConnectError.InvalidTypeException
      }
    }
  }
}