in sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala [90:405]
/**
 * Builds a [[Deserializer]] that reads values of the type described by `encoder` out of the
 * given Arrow `data` (typically a [[FieldVector]], or a struct/list/map vector for composite
 * encoders).
 *
 * The match below pairs each encoder shape with the Arrow vector layout it expects and returns
 * an anonymous deserializer that pulls row `i` from the vector. NOTE: the order of the cases is
 * significant — e.g. `NullEncoder`, `OptionEncoder` and `TransformingEncoder` wrap or bypass
 * the vector-specific cases and must stay where they are.
 *
 * @param encoder    agnostic encoder describing the target JVM type
 * @param data       Arrow data to read from; its concrete class must match the encoder
 * @param timeZoneId session time zone, forwarded to leaf readers (presumably used by the
 *                   timestamp/date readers — exact use is inside the reader implementations)
 * @return a deserializer producing one value per row index
 * @throws RuntimeException for an unsupported encoder/vector combination
 */
private[arrow] def deserializerFor(
encoder: AgnosticEncoder[_],
data: AnyRef,
timeZoneId: String): Deserializer[Any] = {
(encoder, data) match {
// --- Primitive / boxed leaf types: one reader call per row. -------------------------------
case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: FieldVector) =>
new LeafFieldDeserializer[Boolean](encoder, v, timeZoneId) {
override def value(i: Int): Boolean = reader.getBoolean(i)
}
case (PrimitiveByteEncoder | BoxedByteEncoder, v: FieldVector) =>
new LeafFieldDeserializer[Byte](encoder, v, timeZoneId) {
override def value(i: Int): Byte = reader.getByte(i)
}
case (PrimitiveShortEncoder | BoxedShortEncoder, v: FieldVector) =>
new LeafFieldDeserializer[Short](encoder, v, timeZoneId) {
override def value(i: Int): Short = reader.getShort(i)
}
case (PrimitiveIntEncoder | BoxedIntEncoder, v: FieldVector) =>
new LeafFieldDeserializer[Int](encoder, v, timeZoneId) {
override def value(i: Int): Int = reader.getInt(i)
}
case (PrimitiveLongEncoder | BoxedLongEncoder, v: FieldVector) =>
new LeafFieldDeserializer[Long](encoder, v, timeZoneId) {
override def value(i: Int): Long = reader.getLong(i)
}
case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: FieldVector) =>
new LeafFieldDeserializer[Float](encoder, v, timeZoneId) {
override def value(i: Int): Float = reader.getFloat(i)
}
case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: FieldVector) =>
new LeafFieldDeserializer[Double](encoder, v, timeZoneId) {
override def value(i: Int): Double = reader.getDouble(i)
}
// Null-typed column: every row deserializes to null regardless of the vector contents.
case (NullEncoder, _: FieldVector) =>
new Deserializer[Any] {
def get(i: Int): Any = null
}
case (StringEncoder, v: FieldVector) =>
new LeafFieldDeserializer[String](encoder, v, timeZoneId) {
override def value(i: Int): String = reader.getString(i)
}
// Java enums are stored as their name string; resolve via the class's static
// `valueOf(String)` through a MethodHandle (built once, invoked per row).
case (JavaEnumEncoder(tag), v: FieldVector) =>
// It would be nice if we can get Enum.valueOf working...
val valueOf = methodLookup.findStatic(
tag.runtimeClass,
"valueOf",
MethodType.methodType(tag.runtimeClass, classOf[String]))
new LeafFieldDeserializer[Enum[_]](encoder, v, timeZoneId) {
override def value(i: Int): Enum[_] = {
valueOf.invoke(reader.getString(i)).asInstanceOf[Enum[_]]
}
}
// Scala Enumeration values are stored as their name string; reflectively obtain the
// singleton Enumeration object once, then look values up by name per row.
case (ScalaEnumEncoder(parent, _), v: FieldVector) =>
val mirror = scala.reflect.runtime.currentMirror
val module = mirror.classSymbol(parent).module.asModule
val enumeration = mirror.reflectModule(module).instance.asInstanceOf[Enumeration]
new LeafFieldDeserializer[Enumeration#Value](encoder, v, timeZoneId) {
override def value(i: Int): Enumeration#Value = {
enumeration.withName(reader.getString(i))
}
}
case (BinaryEncoder, v: FieldVector) =>
new LeafFieldDeserializer[Array[Byte]](encoder, v, timeZoneId) {
override def value(i: Int): Array[Byte] = reader.getBytes(i)
}
// --- Decimal / big-integer variants, differing only in the target JVM type. --------------
case (SparkDecimalEncoder(_), v: FieldVector) =>
new LeafFieldDeserializer[Decimal](encoder, v, timeZoneId) {
override def value(i: Int): Decimal = reader.getDecimal(i)
}
case (ScalaDecimalEncoder(_), v: FieldVector) =>
new LeafFieldDeserializer[BigDecimal](encoder, v, timeZoneId) {
override def value(i: Int): BigDecimal = reader.getScalaDecimal(i)
}
case (JavaDecimalEncoder(_, _), v: FieldVector) =>
new LeafFieldDeserializer[JBigDecimal](encoder, v, timeZoneId) {
override def value(i: Int): JBigDecimal = reader.getJavaDecimal(i)
}
case (ScalaBigIntEncoder, v: FieldVector) =>
new LeafFieldDeserializer[BigInt](encoder, v, timeZoneId) {
override def value(i: Int): BigInt = reader.getScalaBigInt(i)
}
case (JavaBigIntEncoder, v: FieldVector) =>
new LeafFieldDeserializer[JBigInteger](encoder, v, timeZoneId) {
override def value(i: Int): JBigInteger = reader.getJavaBigInt(i)
}
// --- Interval and date/time leaf types. ---------------------------------------------------
case (DayTimeIntervalEncoder, v: FieldVector) =>
new LeafFieldDeserializer[Duration](encoder, v, timeZoneId) {
override def value(i: Int): Duration = reader.getDuration(i)
}
case (YearMonthIntervalEncoder, v: FieldVector) =>
new LeafFieldDeserializer[Period](encoder, v, timeZoneId) {
override def value(i: Int): Period = reader.getPeriod(i)
}
case (DateEncoder(_), v: FieldVector) =>
new LeafFieldDeserializer[java.sql.Date](encoder, v, timeZoneId) {
override def value(i: Int): java.sql.Date = reader.getDate(i)
}
case (LocalDateEncoder(_), v: FieldVector) =>
new LeafFieldDeserializer[LocalDate](encoder, v, timeZoneId) {
override def value(i: Int): LocalDate = reader.getLocalDate(i)
}
case (TimestampEncoder(_), v: FieldVector) =>
new LeafFieldDeserializer[java.sql.Timestamp](encoder, v, timeZoneId) {
override def value(i: Int): java.sql.Timestamp = reader.getTimestamp(i)
}
case (InstantEncoder(_), v: FieldVector) =>
new LeafFieldDeserializer[Instant](encoder, v, timeZoneId) {
override def value(i: Int): Instant = reader.getInstant(i)
}
case (LocalDateTimeEncoder, v: FieldVector) =>
new LeafFieldDeserializer[LocalDateTime](encoder, v, timeZoneId) {
override def value(i: Int): LocalDateTime = reader.getLocalDateTime(i)
}
// Option[T]: recursively build the deserializer for the element type, then wrap each
// value with Option(_) so a null row becomes None.
case (OptionEncoder(value), v) =>
val deserializer = deserializerFor(value, v, timeZoneId)
new Deserializer[Any] {
override def get(i: Int): Any = Option(deserializer.get(i))
}
// Array[T]: element deserializer runs against the list's backing data vector.
case (ArrayEncoder(element, _), v: ListVector) =>
val deserializer = deserializerFor(element, v.getDataVector, timeZoneId)
new VectorFieldDeserializer[AnyRef, ListVector](v) {
def value(i: Int): AnyRef = getArray(vector, i, deserializer)(element.clsTag)
}
// Scala/Java collection types: dispatch on the concrete collection class of the encoder.
case (IterableEncoder(tag, element, _, _), v: ListVector) =>
val deserializer = deserializerFor(element, v.getDataVector, timeZoneId)
if (isSubClass(Classes.MUTABLE_ARRAY_SEQ, tag)) {
// mutable ArraySeq is a bit special because we need to use an array of the element type.
// Some parts of our codebase (unfortunately) rely on this for type inference on results.
new VectorFieldDeserializer[mutable.ArraySeq[Any], ListVector](v) {
def value(i: Int): mutable.ArraySeq[Any] = {
val array = getArray(vector, i, deserializer)(element.clsTag)
ScalaCollectionUtils.wrap(array)
}
}
} else if (isSubClass(Classes.IMMUTABLE_ARRAY_SEQ, tag)) {
new VectorFieldDeserializer[immutable.ArraySeq[Any], ListVector](v) {
def value(i: Int): immutable.ArraySeq[Any] = {
val array = getArray(vector, i, deserializer)(element.clsTag)
array.asInstanceOf[Array[_]].toImmutableArraySeq
}
}
} else if (isSubClass(Classes.ITERABLE, tag)) {
// Any other Scala Iterable: build via the collection companion's builder.
val companion = ScalaCollectionUtils.getIterableCompanion(tag)
new VectorFieldDeserializer[Iterable[Any], ListVector](v) {
def value(i: Int): Iterable[Any] = {
val builder = companion.newBuilder[Any]
loadListIntoBuilder(vector, i, deserializer, builder)
builder.result()
}
}
} else if (isSubClass(Classes.JLIST, tag)) {
// java.util.List: instantiate the requested list class and fill it by walking the
// list vector's element range [start, end) for row i.
val newInstance = resolveJavaListCreator(tag)
new VectorFieldDeserializer[JList[Any], ListVector](v) {
def value(i: Int): JList[Any] = {
var index = v.getElementStartIndex(i)
val end = v.getElementEndIndex(i)
val list = newInstance(end - index)
while (index < end) {
list.add(deserializer.get(index))
index += 1
}
list
}
}
} else {
throw unsupportedCollectionType(tag.runtimeClass)
}
// Maps arrive as a list of {key, value} structs; build key/value deserializers against
// the struct's child vectors, then dispatch on Scala vs Java map class.
case (MapEncoder(tag, key, value, _), v: MapVector) =>
val structVector = v.getDataVector.asInstanceOf[StructVector]
val keyDeserializer =
deserializerFor(key, structVector.getChild(MapVector.KEY_NAME), timeZoneId)
val valueDeserializer =
deserializerFor(value, structVector.getChild(MapVector.VALUE_NAME), timeZoneId)
if (isSubClass(Classes.MAP, tag)) {
val companion = ScalaCollectionUtils.getMapCompanion(tag)
new VectorFieldDeserializer[Map[Any, Any], MapVector](v) {
def value(i: Int): Map[Any, Any] = {
val builder = companion.newBuilder[Any, Any]
var index = v.getElementStartIndex(i)
val end = v.getElementEndIndex(i)
builder.sizeHint(end - index)
while (index < end) {
builder += (keyDeserializer.get(index) -> valueDeserializer.get(index))
index += 1
}
builder.result()
}
}
} else if (isSubClass(Classes.JMAP, tag)) {
val newInstance = resolveJavaMapCreator(tag)
new VectorFieldDeserializer[JMap[Any, Any], MapVector](v) {
def value(i: Int): JMap[Any, Any] = {
val map = newInstance()
var index = v.getElementStartIndex(i)
val end = v.getElementEndIndex(i)
while (index < end) {
map.put(keyDeserializer.get(index), valueDeserializer.get(index))
index += 1
}
map
}
}
} else {
throw unsupportedCollectionType(tag.runtimeClass)
}
// Case classes / tuples: resolve a matching constructor reflectively, build one child
// deserializer per field, and invoke the constructor per row.
// NOTE: the `val Some(constructor)` extractor throws a MatchError if no matching
// constructor is found.
case (ProductEncoder(tag, fields, outerPointerGetter), StructVectors(struct, vectors)) =>
// Outer pointer (if any) is the enclosing instance of an inner class; it is passed as
// the constructor's first argument via a constant deserializer below.
val outer = outerPointerGetter.map(_()).toSeq
// We should try to make this work with MethodHandles.
val Some(constructor) =
ScalaReflection.findConstructor(
tag.runtimeClass,
outer.map(_.getClass) ++ fields.map(_.enc.clsTag.runtimeClass))
val deserializers = if (isTuple(tag.runtimeClass)) {
// Tuples: fields are positional, so pair them with vectors by position.
fields.zip(vectors).map { case (field, vector) =>
deserializerFor(field.enc, vector, timeZoneId)
}
} else {
// Non-tuple products: look vectors up by field name instead of position.
val outerDeserializer = outer.map { value =>
new Deserializer[Any] {
override def get(i: Int): Any = value
}
}
val lookup = createFieldLookup(vectors)
outerDeserializer ++ fields.map { field =>
deserializerFor(field.enc, lookup(field.name), timeZoneId)
}
}
new StructFieldSerializer[Any](struct) {
def value(i: Int): Any = {
constructor(deserializers.map(_.get(i).asInstanceOf[AnyRef]))
}
}
// Generic Row: one deserializer per named field, producing a GenericRowWithSchema.
case (r @ RowEncoder(fields), StructVectors(struct, vectors)) =>
val lookup = createFieldLookup(vectors)
val deserializers = fields.toArray.map { field =>
deserializerFor(field.enc, lookup(field.name), timeZoneId)
}
new StructFieldSerializer[Any](struct) {
def value(i: Int): Any = {
val values = deserializers.map(_.get(i))
new GenericRowWithSchema(values, r.schema)
}
}
// Variant: a struct of two binary children, "value" and "metadata"; the metadata child
// must carry the Arrow field metadata entry variant=true.
case (VariantEncoder, StructVectors(struct, vectors)) =>
assert(vectors.exists(_.getName == "value"))
assert(
vectors.exists(field =>
field.getName == "metadata" && field.getField.getMetadata
.containsKey("variant") && field.getField.getMetadata.get("variant") == "true"))
val valueDecoder =
deserializerFor(
BinaryEncoder,
vectors
.find(_.getName == "value")
.getOrElse(throw CompilationErrors.columnNotFoundError("value")),
timeZoneId)
val metadataDecoder =
deserializerFor(
BinaryEncoder,
vectors
.find(_.getName == "metadata")
.getOrElse(throw CompilationErrors.columnNotFoundError("metadata")),
timeZoneId)
new StructFieldSerializer[VariantVal](struct) {
def value(i: Int): VariantVal = {
new VariantVal(
valueDecoder.get(i).asInstanceOf[Array[Byte]],
metadataDecoder.get(i).asInstanceOf[Array[Byte]])
}
}
// Java beans: instantiate via the no-arg constructor, then apply one setter per
// writable field (fields without a write method are skipped).
case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) =>
val constructor =
methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit]))
val lookup = createFieldLookup(vectors)
val setters = fields
.filter(_.writeMethod.isDefined)
.map { field =>
val vector = lookup(field.name)
val deserializer = deserializerFor(field.enc, vector, timeZoneId)
val setter = methodLookup.findVirtual(
tag.runtimeClass,
field.writeMethod.get,
MethodType.methodType(classOf[Unit], field.enc.clsTag.runtimeClass))
(bean: Any, i: Int) => setter.invoke(bean, deserializer.get(i))
}
new StructFieldSerializer[Any](struct) {
def value(i: Int): Any = {
val instance = constructor.invoke()
setters.foreach(_(instance, i))
instance
}
}
// TransformingEncoder: deserialize with the inner encoder, then decode through the
// user-provided codec. The codec is instantiated once per deserializer.
case (TransformingEncoder(_, encoder, provider, _), v) =>
new Deserializer[Any] {
private[this] val codec = provider()
private[this] val deserializer = deserializerFor(encoder, v, timeZoneId)
override def get(i: Int): Any = codec.decode(deserializer.get(i))
}
// Explicitly unsupported encoders.
case (CalendarIntervalEncoder | _: UDTEncoder[_], _) =>
throw ExecutionErrors.unsupportedDataTypeError(encoder.dataType)
// Catch-all: encoder/vector pairing we do not know how to read.
case _ =>
throw new RuntimeException(
s"Unsupported Encoder($encoder)/Vector(${data.getClass}) combination.")
}
}