def makeFromJava()

in sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala [90:212]

Returns a converter from a Java value, as deserialized by Py4J, to Catalyst's internal representation of the given DataType. Each converter is guarded by nullSafeConvert, which handles null input as well as any value the listed cases do not match.

  def makeFromJava(dataType: DataType): Any => Any = dataType match {
    case BooleanType => (obj: Any) => nullSafeConvert(obj) {
      case b: Boolean => b
    }

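    // Integral targets accept any JVM integral source; narrowing follows JVM
    // semantics, so out-of-range values wrap silently rather than fail.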
    case ByteType => (obj: Any) => nullSafeConvert(obj) {
      case c: Byte => c
      case c: Short => c.toByte
      case c: Int => c.toByte
      case c: Long => c.toByte
    }

    case ShortType => (obj: Any) => nullSafeConvert(obj) {
      case c: Byte => c.toShort
      case c: Short => c
      case c: Int => c.toShort
      case c: Long => c.toShort
    }

    case IntegerType => (obj: Any) => nullSafeConvert(obj) {
      case c: Byte => c.toInt
      case c: Short => c.toInt
      case c: Int => c
      case c: Long => c.toInt
    }

    case LongType => (obj: Any) => nullSafeConvert(obj) {
      case c: Byte => c.toLong
      case c: Short => c.toLong
      case c: Int => c.toLong
      case c: Long => c
    }

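    // Fractional targets accept Float or Double; Double-to-Float narrowing can lose precision.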
    case FloatType => (obj: Any) => nullSafeConvert(obj) {
      case c: Float => c
      case c: Double => c.toFloat
    }

    case DoubleType => (obj: Any) => nullSafeConvert(obj) {
      case c: Float => c.toDouble
      case c: Double => c
    }

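    // Decimals arrive as java.math.BigDecimal and adopt the target type's precision and scale.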
    case dt: DecimalType => (obj: Any) => nullSafeConvert(obj) {
      case c: java.math.BigDecimal => Decimal(c, dt.precision, dt.scale)
    }

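    // Dates are exchanged as Ints holding days since the Unix epoch, Catalyst's internal date form.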
    case DateType => (obj: Any) => nullSafeConvert(obj) {
      case c: Int => c
    }

    case TimestampType | TimestampNTZType | _: DayTimeIntervalType => (obj: Any) =>
      nullSafeConvert(obj) {
        case c: Long => c
        // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
        case c: Int => c.toLong
      }

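    // Accept anything: render via toString and wrap as a Catalyst UTF8String.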
    case _: StringType => (obj: Any) => nullSafeConvert(obj) {
      case _ => UTF8String.fromString(obj.toString)
    }

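    // Strings are encoded as UTF-8 bytes; primitive byte arrays pass through unchanged.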
    case BinaryType => (obj: Any) => nullSafeConvert(obj) {
      case c: String => c.getBytes(StandardCharsets.UTF_8)
      case c if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
    }

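    // Build the element converter once, then apply it per element of a java.util.List or JVM array.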
    case ArrayType(elementType, _) =>
      val elementFromJava = makeFromJava(elementType)

      (obj: Any) => nullSafeConvert(obj) {
        case c: java.util.List[_] =>
          new GenericArrayData(c.asScala.map { e => elementFromJava(e) }.toArray)
        case c if c.getClass.isArray =>
          new GenericArrayData(c.asInstanceOf[Array[_]].map(e => elementFromJava(e)))
      }

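    // Convert keys and values with their own converters; ArrayBasedMapData copies the map into Catalyst form.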
    case MapType(keyType, valueType, _) =>
      val keyFromJava = makeFromJava(keyType)
      val valueFromJava = makeFromJava(valueType)

      (obj: Any) => nullSafeConvert(obj) {
        case javaMap: java.util.Map[_, _] =>
          ArrayBasedMapData(
            javaMap,
            (key: Any) => keyFromJava(key),
            (value: Any) => valueFromJava(value))
      }

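    // Rows arrive as positional arrays; the length must match the schema before fields are converted one by one.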
    case StructType(fields) =>
      val fieldsFromJava = fields.map(f => makeFromJava(f.dataType))

      (obj: Any) => nullSafeConvert(obj) {
        case c if c.getClass.isArray =>
          val array = c.asInstanceOf[Array[_]]
          if (array.length != fields.length) {
            throw new SparkIllegalArgumentException(
              errorClass = "STRUCT_ARRAY_LENGTH_MISMATCH",
              messageParameters = Map(
                "expected" -> fields.length.toString,
                "actual" -> array.length.toString))
          }

          val row = new GenericInternalRow(fields.length)
          var i = 0
          while (i < fields.length) {
            row(i) = fieldsFromJava(i)(array(i))
            i += 1
          }
          row
      }

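    // A UDT is converted through its underlying SQL type.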
    case udt: UserDefinedType[_] => makeFromJava(udt.sqlType)

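    // A variant arrives as a map carrying its binary "value" and "metadata" components.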
    case VariantType => (obj: Any) => nullSafeConvert(obj) {
      case s: java.util.HashMap[_, _] =>
        new VariantVal(
          s.get("value").asInstanceOf[Array[Byte]], s.get("metadata").asInstanceOf[Array[Byte]]
        )
    }

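    // Any other type accepts nothing; the empty partial function leaves every
    // non-null input to nullSafeConvert's fallback.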
    case other => (obj: Any) => nullSafeConvert(obj)(PartialFunction.empty)
  }
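
For context, every branch above delegates to nullSafeConvert, defined elsewhere in the same file. A minimal sketch of its contract, inferred from the usage here rather than copied from the source: null stays null, and a non-null input that no case matches falls back to null instead of throwing a MatchError.

  private def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = {
    if (input == null) {
      null
    } else {
      // applyOrElse avoids the MatchError a bare partial-function call would throw.
      f.applyOrElse(input, (_: Any) => null)
    }
  }

And a hypothetical end-to-end use of the struct branch (schema and values invented for illustration):

  // Py4J delivers a row as a positional Array[Any]; small Python ints arrive as Integers.
  val convert = makeFromJava(
    StructType(Seq(StructField("id", LongType), StructField("name", StringType))))

  val row = convert(Array[Any](Integer.valueOf(42), "alice"))
  // => GenericInternalRow(42L, UTF8String.fromString("alice")), via the Int-to-Long
  //    widening case and the toString/UTF8String case above.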