in sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala [90:212]
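/**
 * Makes a converter that converts a Java object (as deserialized from the Python side) to the
 * Catalyst internal representation for the given data type, or to null if the object's runtime
 * type is unexpected, since Python does not enforce the declared type.
 */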
def makeFromJava(dataType: DataType): Any => Any = dataType match {
case BooleanType => (obj: Any) => nullSafeConvert(obj) {
case b: Boolean => b
}
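// Integral types: a Python int may arrive as any JVM integral width, so accept
// Byte/Short/Int/Long and narrow to the target type (narrowing truncates on overflow).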
case ByteType => (obj: Any) => nullSafeConvert(obj) {
case c: Byte => c
case c: Short => c.toByte
case c: Int => c.toByte
case c: Long => c.toByte
}
case ShortType => (obj: Any) => nullSafeConvert(obj) {
case c: Byte => c.toShort
case c: Short => c
case c: Int => c.toShort
case c: Long => c.toShort
}
case IntegerType => (obj: Any) => nullSafeConvert(obj) {
case c: Byte => c.toInt
case c: Short => c.toInt
case c: Int => c
case c: Long => c.toInt
}
case LongType => (obj: Any) => nullSafeConvert(obj) {
case c: Byte => c.toLong
case c: Short => c.toLong
case c: Int => c.toLong
case c: Long => c
}
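// Fractional types: accept both Float and Double and convert to the declared width.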
case FloatType => (obj: Any) => nullSafeConvert(obj) {
case c: Float => c
case c: Double => c.toFloat
}
case DoubleType => (obj: Any) => nullSafeConvert(obj) {
case c: Float => c.toDouble
case c: Double => c
}
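// Decimals arrive as java.math.BigDecimal and are rewrapped with the declared
// precision and scale.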
case dt: DecimalType => (obj: Any) => nullSafeConvert(obj) {
case c: java.math.BigDecimal => Decimal(c, dt.precision, dt.scale)
}
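// Dates are transferred as the number of days since the Unix epoch.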
case DateType => (obj: Any) => nullSafeConvert(obj) {
case c: Int => c
}
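// Timestamps and day-time intervals are transferred as microseconds, normally as a Long.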
case TimestampType | TimestampNTZType | _: DayTimeIntervalType => (obj: Any) =>
nullSafeConvert(obj) {
case c: Long => c
// Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
case c: Int => c.toLong
}
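// Strings: accept any object and convert it via toString to Catalyst's UTF8String.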
case _: StringType => (obj: Any) => nullSafeConvert(obj) {
case _ => UTF8String.fromString(obj.toString)
}
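// Binary: a JVM String is UTF-8 encoded; a primitive byte[] is passed through unchanged.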
case BinaryType => (obj: Any) => nullSafeConvert(obj) {
case c: String => c.getBytes(StandardCharsets.UTF_8)
case c if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
}
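// Arrays: convert each element recursively; the input may be a java.util.List or a JVM array.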
case ArrayType(elementType, _) =>
val elementFromJava = makeFromJava(elementType)
(obj: Any) => nullSafeConvert(obj) {
case c: java.util.List[_] =>
new GenericArrayData(c.asScala.map { e => elementFromJava(e) }.toArray)
case c if c.getClass.isArray =>
new GenericArrayData(c.asInstanceOf[Array[_]].map(e => elementFromJava(e)))
}
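// Maps: convert keys and values recursively and wrap the result in an ArrayBasedMapData.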
case MapType(keyType, valueType, _) =>
val keyFromJava = makeFromJava(keyType)
val valueFromJava = makeFromJava(valueType)
(obj: Any) => nullSafeConvert(obj) {
case javaMap: java.util.Map[_, _] =>
ArrayBasedMapData(
javaMap,
(key: Any) => keyFromJava(key),
(value: Any) => valueFromJava(value))
}
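// Structs arrive as an Object[] with one entry per field; the length must match the schema.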
case StructType(fields) =>
val fieldsFromJava = fields.map(f => makeFromJava(f.dataType))
(obj: Any) => nullSafeConvert(obj) {
case c if c.getClass.isArray =>
val array = c.asInstanceOf[Array[_]]
if (array.length != fields.length) {
throw new SparkIllegalArgumentException(
errorClass = "STRUCT_ARRAY_LENGTH_MISMATCH",
messageParameters = Map(
"expected" -> fields.length.toString,
"actual" -> array.length.toString))
}
val row = new GenericInternalRow(fields.length)
var i = 0
while (i < fields.length) {
row(i) = fieldsFromJava(i)(array(i))
i += 1
}
row
}
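// User-defined types are converted through their underlying SQL type.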
case udt: UserDefinedType[_] => makeFromJava(udt.sqlType)
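// Variants arrive as a map carrying the binary "value" and "metadata" components.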
case VariantType => (obj: Any) => nullSafeConvert(obj) {
case s: java.util.HashMap[_, _] =>
new VariantVal(
s.get("value").asInstanceOf[Array[Byte]], s.get("metadata").asInstanceOf[Array[Byte]]
)
}
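// Any other type accepts no runtime class, so every non-null input converts to null.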
case _ => (obj: Any) => nullSafeConvert(obj)(PartialFunction.empty)
}
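// Illustrative usage sketch (not part of the original file; the schema and values below
// are assumptions for demonstration):
//   val convert = makeFromJava(new StructType().add("id", LongType).add("name", StringType))
//   val row = convert(Array[Any](42L, "spark"))
//   // => GenericInternalRow(42L, UTF8String.fromString("spark"))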