// sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala [243:508]
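/**
 * Returns a [[Serializer]] that writes values described by `encoder` into the Arrow vector
 * `v`. Each case below pairs an `AgnosticEncoder` with the Arrow vector type it expects,
 * and nested encoders (options, collections, maps, structs) recurse through this method so
 * that a single top-level serializer can write a whole row object.
 */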
private[arrow] def serializerFor[E](encoder: AgnosticEncoder[E], v: AnyRef): Serializer = {
(encoder, v) match {
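// Fixed-width primitives (and their boxed forms) map one-to-one onto Arrow's fixed-width
// vectors. `setSafe` is used throughout so the underlying buffers grow on demand instead
// of overflowing. Booleans are stored as 0/1 bits; NullEncoder only ever writes nulls.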
case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: BitVector) =>
new FieldSerializer[Boolean, BitVector](v) {
override def set(index: Int, value: Boolean): Unit =
vector.setSafe(index, if (value) 1 else 0)
}
case (PrimitiveByteEncoder | BoxedByteEncoder, v: TinyIntVector) =>
new FieldSerializer[Byte, TinyIntVector](v) {
override def set(index: Int, value: Byte): Unit = vector.setSafe(index, value)
}
case (PrimitiveShortEncoder | BoxedShortEncoder, v: SmallIntVector) =>
new FieldSerializer[Short, SmallIntVector](v) {
override def set(index: Int, value: Short): Unit = vector.setSafe(index, value)
}
case (PrimitiveIntEncoder | BoxedIntEncoder, v: IntVector) =>
new FieldSerializer[Int, IntVector](v) {
override def set(index: Int, value: Int): Unit = vector.setSafe(index, value)
}
case (PrimitiveLongEncoder | BoxedLongEncoder, v: BigIntVector) =>
new FieldSerializer[Long, BigIntVector](v) {
override def set(index: Int, value: Long): Unit = vector.setSafe(index, value)
}
case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: Float4Vector) =>
new FieldSerializer[Float, Float4Vector](v) {
override def set(index: Int, value: Float): Unit = vector.setSafe(index, value)
}
case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: Float8Vector) =>
new FieldSerializer[Double, Float8Vector](v) {
override def set(index: Int, value: Double): Unit = vector.setSafe(index, value)
}
case (NullEncoder, v: NullVector) =>
new FieldSerializer[Unit, NullVector](v) {
override def set(index: Int, value: Unit): Unit = vector.setNull(index)
}
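// Strings and enums are written as UTF-8 text: Java enums via `name()`, Scala enumeration
// values via `toString`. Both the VarChar and LargeVarChar variants are supported (the
// latter uses 64-bit offsets for data exceeding 2 GiB per vector).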
case (StringEncoder, v: VarCharVector) =>
new FieldSerializer[String, VarCharVector](v) {
override def set(index: Int, value: String): Unit = setString(v, index, value)
}
case (StringEncoder, v: LargeVarCharVector) =>
new FieldSerializer[String, LargeVarCharVector](v) {
override def set(index: Int, value: String): Unit = setString(v, index, value)
}
case (JavaEnumEncoder(_), v: VarCharVector) =>
new FieldSerializer[Enum[_], VarCharVector](v) {
override def set(index: Int, value: Enum[_]): Unit = setString(v, index, value.name())
}
case (JavaEnumEncoder(_), v: LargeVarCharVector) =>
new FieldSerializer[Enum[_], LargeVarCharVector](v) {
override def set(index: Int, value: Enum[_]): Unit = setString(v, index, value.name())
}
case (ScalaEnumEncoder(_, _), v: VarCharVector) =>
new FieldSerializer[Enumeration#Value, VarCharVector](v) {
override def set(index: Int, value: Enumeration#Value): Unit =
setString(v, index, value.toString)
}
case (ScalaEnumEncoder(_, _), v: LargeVarCharVector) =>
new FieldSerializer[Enumeration#Value, LargeVarCharVector](v) {
override def set(index: Int, value: Enumeration#Value): Unit =
setString(v, index, value.toString)
}
case (BinaryEncoder, v: VarBinaryVector) =>
new FieldSerializer[Array[Byte], VarBinaryVector](v) {
override def set(index: Int, value: Array[Byte]): Unit = vector.setSafe(index, value)
}
case (BinaryEncoder, v: LargeVarBinaryVector) =>
new FieldSerializer[Array[Byte], LargeVarBinaryVector](v) {
override def set(index: Int, value: Array[Byte]): Unit = vector.setSafe(index, value)
}
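// Decimal-like encoders all normalize to java.math.BigDecimal before delegating to
// `setDecimal`. The lenient JavaDecimalEncoder (second flag = true) additionally accepts
// Scala BigDecimal, BigInt, java.math.BigInteger, and Spark's internal Decimal at runtime.
// The BigInt/BigInteger encoders store their values as scale-0 decimals.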
case (SparkDecimalEncoder(_), v: DecimalVector) =>
new FieldSerializer[Decimal, DecimalVector](v) {
override def set(index: Int, value: Decimal): Unit =
setDecimal(vector, index, value.toJavaBigDecimal)
}
case (ScalaDecimalEncoder(_), v: DecimalVector) =>
new FieldSerializer[BigDecimal, DecimalVector](v) {
override def set(index: Int, value: BigDecimal): Unit =
setDecimal(vector, index, value.bigDecimal)
}
case (JavaDecimalEncoder(_, false), v: DecimalVector) =>
new FieldSerializer[JBigDecimal, DecimalVector](v) {
override def set(index: Int, value: JBigDecimal): Unit =
setDecimal(vector, index, value)
}
case (JavaDecimalEncoder(_, true), v: DecimalVector) =>
new FieldSerializer[Any, DecimalVector](v) {
override def set(index: Int, value: Any): Unit = {
val decimal = value match {
case j: JBigDecimal => j
case d: BigDecimal => d.bigDecimal
case k: BigInt => new JBigDecimal(k.bigInteger)
case l: JBigInteger => new JBigDecimal(l)
case d: Decimal => d.toJavaBigDecimal
}
setDecimal(vector, index, decimal)
}
}
case (ScalaBigIntEncoder, v: DecimalVector) =>
new FieldSerializer[BigInt, DecimalVector](v) {
override def set(index: Int, value: BigInt): Unit =
setDecimal(vector, index, new JBigDecimal(value.bigInteger))
}
case (JavaBigIntEncoder, v: DecimalVector) =>
new FieldSerializer[JBigInteger, DecimalVector](v) {
override def set(index: Int, value: JBigInteger): Unit =
setDecimal(vector, index, new JBigDecimal(value))
}
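// ANSI interval types: day-time intervals are stored as microseconds in a DurationVector,
// year-month intervals as a month count in an IntervalYearVector.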
case (DayTimeIntervalEncoder, v: DurationVector) =>
new FieldSerializer[Duration, DurationVector](v) {
override def set(index: Int, value: Duration): Unit =
vector.setSafe(index, SparkIntervalUtils.durationToMicros(value))
}
case (YearMonthIntervalEncoder, v: IntervalYearVector) =>
new FieldSerializer[Period, IntervalYearVector](v) {
override def set(index: Int, value: Period): Unit =
vector.setSafe(index, SparkIntervalUtils.periodToMonths(value))
}
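// Dates and timestamps. A `true` flag selects the lenient variant, which accepts either
// the java.sql or the java.time representation and converts via anyToDays/anyToMicros;
// the strict variants are bound to a single input class. LocalDateTime carries no zone
// and so targets TimeStampMicroVector instead of the zone-aware TimeStampMicroTZVector.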
case (DateEncoder(true) | LocalDateEncoder(true), v: DateDayVector) =>
new FieldSerializer[Any, DateDayVector](v) {
override def set(index: Int, value: Any): Unit =
vector.setSafe(index, SparkDateTimeUtils.anyToDays(value))
}
case (DateEncoder(false), v: DateDayVector) =>
new FieldSerializer[java.sql.Date, DateDayVector](v) {
override def set(index: Int, value: java.sql.Date): Unit =
vector.setSafe(index, SparkDateTimeUtils.fromJavaDate(value))
}
case (LocalDateEncoder(false), v: DateDayVector) =>
new FieldSerializer[LocalDate, DateDayVector](v) {
override def set(index: Int, value: LocalDate): Unit =
vector.setSafe(index, SparkDateTimeUtils.localDateToDays(value))
}
case (TimestampEncoder(true) | InstantEncoder(true), v: TimeStampMicroTZVector) =>
new FieldSerializer[Any, TimeStampMicroTZVector](v) {
override def set(index: Int, value: Any): Unit =
vector.setSafe(index, SparkDateTimeUtils.anyToMicros(value))
}
case (TimestampEncoder(false), v: TimeStampMicroTZVector) =>
new FieldSerializer[java.sql.Timestamp, TimeStampMicroTZVector](v) {
override def set(index: Int, value: java.sql.Timestamp): Unit =
vector.setSafe(index, SparkDateTimeUtils.fromJavaTimestamp(value))
}
case (InstantEncoder(false), v: TimeStampMicroTZVector) =>
new FieldSerializer[Instant, TimeStampMicroTZVector](v) {
override def set(index: Int, value: Instant): Unit =
vector.setSafe(index, SparkDateTimeUtils.instantToMicros(value))
}
case (LocalDateTimeEncoder, v: TimeStampMicroVector) =>
new FieldSerializer[LocalDateTime, TimeStampMicroVector](v) {
override def set(index: Int, value: LocalDateTime): Unit =
vector.setSafe(index, SparkDateTimeUtils.localDateTimeToMicros(value))
}
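// Option[T] reuses its element encoder's serializer: Some(x) writes x, while None writes
// null, leaving the slot unset (i.e. null) in the target vector.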
case (OptionEncoder(value), v) =>
new Serializer {
private[this] val delegate: Serializer = serializerFor(value, v)
override def write(index: Int, value: Any): Unit = value match {
case Some(value) => delegate.write(index, value)
case _ => delegate.write(index, null)
}
}
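// Sequence-like values are written through ArraySerializer, which appends each element to
// the list vector's inner data vector. The lenient IterableEncoder path dispatches on the
// runtime class, so a loosely typed field can hold a Scala Iterable, a java.util.List, or
// an Array interchangeably.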
case (ArrayEncoder(element, _), v: ListVector) =>
val elementSerializer = serializerFor(element, v.getDataVector)
val toIterator = { array: Any =>
array.asInstanceOf[Array[_]].iterator
}
new ArraySerializer(v, toIterator, elementSerializer)
case (IterableEncoder(tag, element, _, lenient), v: ListVector) =>
val elementSerializer = serializerFor(element, v.getDataVector)
val toIterator: Any => Iterator[_] = if (lenient) {
{
case i: scala.collection.Iterable[_] => i.iterator
case l: java.util.List[_] => l.iterator().asScala
case a: Array[_] => a.iterator
case o => unsupportedCollectionType(o.getClass)
}
} else if (isSubClass(Classes.ITERABLE, tag)) { v =>
v.asInstanceOf[scala.collection.Iterable[_]].iterator
} else if (isSubClass(Classes.JLIST, tag)) { v =>
v.asInstanceOf[java.util.List[_]].iterator().asScala
} else {
unsupportedCollectionType(tag.runtimeClass)
}
new ArraySerializer(v, toIterator, elementSerializer)
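// Maps follow Arrow's MapVector layout: a list of key/value structs. `extractKey` and
// `extractValue` (defined elsewhere in this file) pull the two sides out of each entry.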
case (MapEncoder(tag, key, value, _), v: MapVector) =>
val structVector = v.getDataVector.asInstanceOf[StructVector]
val extractor = if (isSubClass(classOf[scala.collection.Map[_, _]], tag)) { (v: Any) =>
v.asInstanceOf[scala.collection.Map[_, _]].iterator
} else if (isSubClass(classOf[JMap[_, _]], tag)) { (v: Any) =>
v.asInstanceOf[JMap[Any, Any]].asScala.iterator
} else {
unsupportedCollectionType(tag.runtimeClass)
}
val structSerializer = new StructSerializer(
structVector,
new StructFieldSerializer(
extractKey,
serializerFor(key, structVector.getChild(MapVector.KEY_NAME))) ::
new StructFieldSerializer(
extractValue,
serializerFor(value, structVector.getChild(MapVector.VALUE_NAME))) :: Nil)
new ArraySerializer(v, extractor, structSerializer)
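// Struct-shaped data. Case classes and tuples are read positionally via productElement;
// DefinedByConstructorParams classes use MethodHandle getters resolved once per field;
// generic Rows are read by ordinal.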
case (ProductEncoder(tag, fields, _), StructVectors(struct, vectors)) =>
if (isSubClass(classOf[Product], tag)) {
structSerializerFor(fields, struct, vectors) { (_, i) => p =>
p.asInstanceOf[Product].productElement(i)
}
} else if (isSubClass(classOf[DefinedByConstructorParams], tag)) {
structSerializerFor(fields, struct, vectors) { (field, _) =>
val getter = methodLookup.findVirtual(
tag.runtimeClass,
field.name,
MethodType.methodType(field.enc.clsTag.runtimeClass))
o => getter.invoke(o)
}
} else {
unsupportedCollectionType(tag.runtimeClass)
}
case (RowEncoder(fields), StructVectors(struct, vectors)) =>
structSerializerFor(fields, struct, vectors) { (_, i) => r => r.asInstanceOf[Row].get(i) }
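// Variants are encoded as a two-field struct: a binary `value` plus a binary `metadata`
// field whose Arrow field metadata carries variant=true. The asserts check that the
// vector was created with exactly that shape before wiring up the two extractors.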
case (VariantEncoder, StructVectors(struct, vectors)) =>
assert(vectors.exists(_.getName == "value"))
assert(
vectors.exists(field =>
field.getName == "metadata" && field.getField.getMetadata
.containsKey("variant") && field.getField.getMetadata.get("variant") == "true"))
new StructSerializer(
struct,
Seq(
new StructFieldSerializer(
extractor = (v: Any) => v.asInstanceOf[VariantVal].getValue,
serializerFor(BinaryEncoder, struct.getChild("value"))),
new StructFieldSerializer(
extractor = (v: Any) => v.asInstanceOf[VariantVal].getMetadata,
serializerFor(BinaryEncoder, struct.getChild("metadata")))))
case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) =>
structSerializerFor(fields, struct, vectors) { (field, _) =>
val getter = methodLookup.findVirtual(
tag.runtimeClass,
field.readMethod.get,
MethodType.methodType(field.enc.clsTag.runtimeClass))
o => getter.invoke(o)
}
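// TransformingEncoder first passes the user object through the Codec obtained from
// `provider()`, then writes the encoded result with the serializer of the underlying
// encoder.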
case (TransformingEncoder(_, encoder, provider, _), v) =>
new Serializer {
private[this] val codec = provider().asInstanceOf[Codec[Any, Any]]
private[this] val delegate: Serializer = serializerFor(encoder, v)
override def write(index: Int, value: Any): Unit =
delegate.write(index, codec.encode(value))
}
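// Calendar intervals and UDTs are not supported by this serializer; any other pairing
// indicates an encoder/vector mismatch introduced by the caller.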
case (CalendarIntervalEncoder | _: UDTEncoder[_], _) =>
throw ExecutionErrors.unsupportedDataTypeError(encoder.dataType)
case _ =>
throw new RuntimeException(s"Unsupported Encoder($encoder)/Vector($v) combination.")
}
}
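// A minimal usage sketch (hypothetical and illustrative only; assumes an Arrow allocator
// is on hand and that the vector type matches the encoder, as the match above requires):
//
//   val allocator = new RootAllocator()
//   val vector = new IntVector("i", allocator)
//   val serializer = serializerFor(BoxedIntEncoder, vector)
//   Seq(1, 2, 3).zipWithIndex.foreach { case (value, i) => serializer.write(i, value) }
//
// In practice this method is typically not called directly: ArrowSerializer builds the
// root serializer for a full row schema and drives it while writing record batches.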