private[arrow] def deserializerFor()

in sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala [90:405]


  /**
   * Create a `Deserializer` that produces values of the type described by `encoder` from the
   * Arrow data in `data`.
   *
   * The `(encoder, data)` pair is matched case by case:
   *   - leaf encoders (primitives, strings, enums, binary, decimals, big integers, intervals,
   *     dates and timestamps) are paired with a `FieldVector`;
   *   - array and iterable encoders are paired with a `ListVector`;
   *   - map encoders are paired with a `MapVector`;
   *   - product, row, variant and Java bean encoders are paired with a struct, destructured by
   *     the `StructVectors` extractor;
   *   - option and transforming encoders wrap the deserializer built for the encoder they
   *     contain, using the same `data`.
   *
   * Deserializers for nested encoders are built eagerly through recursive calls, so an
   * unsupported nested encoder/vector pair fails while the deserializer is being created rather
   * than when the first value is read.
   *
   * @param encoder
   *   encoder describing the value to produce.
   * @param data
   *   the Arrow object to read from. It is only inspected through pattern matching, so a value
   *   of an unexpected type (or `null`) ends up in the final "unsupported combination" branch.
   * @param timeZoneId
   *   passed unchanged to every leaf deserializer and to every recursive call.
   * @return
   *   a deserializer whose `get(i)` returns the value at index `i`.
   * @throws java.lang.RuntimeException
   *   if no case matches the encoder/data combination, or if no constructor matching the
   *   fields of a product encoder can be found. Unsupported collection classes and
   *   unsupported data types raise the errors built by `unsupportedCollectionType` and
   *   `ExecutionErrors.unsupportedDataTypeError` respectively.
   */
  private[arrow] def deserializerFor(
      encoder: AgnosticEncoder[_],
      data: AnyRef,
      timeZoneId: String): Deserializer[Any] = {
    (encoder, data) match {
      // Leaf encoders: each one reads a single value from the vector through `reader`.
      case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: FieldVector) =>
        new LeafFieldDeserializer[Boolean](encoder, v, timeZoneId) {
          override def value(i: Int): Boolean = reader.getBoolean(i)
        }
      case (PrimitiveByteEncoder | BoxedByteEncoder, v: FieldVector) =>
        new LeafFieldDeserializer[Byte](encoder, v, timeZoneId) {
          override def value(i: Int): Byte = reader.getByte(i)
        }
      case (PrimitiveShortEncoder | BoxedShortEncoder, v: FieldVector) =>
        new LeafFieldDeserializer[Short](encoder, v, timeZoneId) {
          override def value(i: Int): Short = reader.getShort(i)
        }
      case (PrimitiveIntEncoder | BoxedIntEncoder, v: FieldVector) =>
        new LeafFieldDeserializer[Int](encoder, v, timeZoneId) {
          override def value(i: Int): Int = reader.getInt(i)
        }
      case (PrimitiveLongEncoder | BoxedLongEncoder, v: FieldVector) =>
        new LeafFieldDeserializer[Long](encoder, v, timeZoneId) {
          override def value(i: Int): Long = reader.getLong(i)
        }
      case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: FieldVector) =>
        new LeafFieldDeserializer[Float](encoder, v, timeZoneId) {
          override def value(i: Int): Float = reader.getFloat(i)
        }
      case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: FieldVector) =>
        new LeafFieldDeserializer[Double](encoder, v, timeZoneId) {
          override def value(i: Int): Double = reader.getDouble(i)
        }
      case (NullEncoder, _: FieldVector) =>
        // Nothing is read from the vector: every index yields null.
        new Deserializer[Any] {
          def get(i: Int): Any = null
        }
      case (StringEncoder, v: FieldVector) =>
        new LeafFieldDeserializer[String](encoder, v, timeZoneId) {
          override def value(i: Int): String = reader.getString(i)
        }
      case (JavaEnumEncoder(tag), v: FieldVector) =>
        // It would be nice if we can get Enum.valueOf working...
        // The static `valueOf(String)` of the enum class is looked up once and invoked per value.
        val valueOf = methodLookup.findStatic(
          tag.runtimeClass,
          "valueOf",
          MethodType.methodType(tag.runtimeClass, classOf[String]))
        new LeafFieldDeserializer[Enum[_]](encoder, v, timeZoneId) {
          override def value(i: Int): Enum[_] = {
            valueOf.invoke(reader.getString(i)).asInstanceOf[Enum[_]]
          }
        }
      case (ScalaEnumEncoder(parent, _), v: FieldVector) =>
        // Reflectively obtain the Enumeration object for `parent` so that values can be resolved
        // from the name stored in the vector.
        val mirror = scala.reflect.runtime.currentMirror
        val module = mirror.classSymbol(parent).module.asModule
        val enumeration = mirror.reflectModule(module).instance.asInstanceOf[Enumeration]
        new LeafFieldDeserializer[Enumeration#Value](encoder, v, timeZoneId) {
          override def value(i: Int): Enumeration#Value = {
            enumeration.withName(reader.getString(i))
          }
        }
      case (BinaryEncoder, v: FieldVector) =>
        new LeafFieldDeserializer[Array[Byte]](encoder, v, timeZoneId) {
          override def value(i: Int): Array[Byte] = reader.getBytes(i)
        }
      case (SparkDecimalEncoder(_), v: FieldVector) =>
        new LeafFieldDeserializer[Decimal](encoder, v, timeZoneId) {
          override def value(i: Int): Decimal = reader.getDecimal(i)
        }
      case (ScalaDecimalEncoder(_), v: FieldVector) =>
        new LeafFieldDeserializer[BigDecimal](encoder, v, timeZoneId) {
          override def value(i: Int): BigDecimal = reader.getScalaDecimal(i)
        }
      case (JavaDecimalEncoder(_, _), v: FieldVector) =>
        new LeafFieldDeserializer[JBigDecimal](encoder, v, timeZoneId) {
          override def value(i: Int): JBigDecimal = reader.getJavaDecimal(i)
        }
      case (ScalaBigIntEncoder, v: FieldVector) =>
        new LeafFieldDeserializer[BigInt](encoder, v, timeZoneId) {
          override def value(i: Int): BigInt = reader.getScalaBigInt(i)
        }
      case (JavaBigIntEncoder, v: FieldVector) =>
        new LeafFieldDeserializer[JBigInteger](encoder, v, timeZoneId) {
          override def value(i: Int): JBigInteger = reader.getJavaBigInt(i)
        }
      case (DayTimeIntervalEncoder, v: FieldVector) =>
        new LeafFieldDeserializer[Duration](encoder, v, timeZoneId) {
          override def value(i: Int): Duration = reader.getDuration(i)
        }
      case (YearMonthIntervalEncoder, v: FieldVector) =>
        new LeafFieldDeserializer[Period](encoder, v, timeZoneId) {
          override def value(i: Int): Period = reader.getPeriod(i)
        }
      case (DateEncoder(_), v: FieldVector) =>
        new LeafFieldDeserializer[java.sql.Date](encoder, v, timeZoneId) {
          override def value(i: Int): java.sql.Date = reader.getDate(i)
        }
      case (LocalDateEncoder(_), v: FieldVector) =>
        new LeafFieldDeserializer[LocalDate](encoder, v, timeZoneId) {
          override def value(i: Int): LocalDate = reader.getLocalDate(i)
        }
      case (TimestampEncoder(_), v: FieldVector) =>
        new LeafFieldDeserializer[java.sql.Timestamp](encoder, v, timeZoneId) {
          override def value(i: Int): java.sql.Timestamp = reader.getTimestamp(i)
        }
      case (InstantEncoder(_), v: FieldVector) =>
        new LeafFieldDeserializer[Instant](encoder, v, timeZoneId) {
          override def value(i: Int): Instant = reader.getInstant(i)
        }
      case (LocalDateTimeEncoder, v: FieldVector) =>
        new LeafFieldDeserializer[LocalDateTime](encoder, v, timeZoneId) {
          override def value(i: Int): LocalDateTime = reader.getLocalDateTime(i)
        }

      case (OptionEncoder(value), v) =>
        // Reads through the deserializer of the wrapped encoder; `Option(...)` maps a null
        // result to None.
        val deserializer = deserializerFor(value, v, timeZoneId)
        new Deserializer[Any] {
          override def get(i: Int): Any = Option(deserializer.get(i))
        }

      case (ArrayEncoder(element, _), v: ListVector) =>
        // Elements are read from the list's data vector.
        val deserializer = deserializerFor(element, v.getDataVector, timeZoneId)
        new VectorFieldDeserializer[AnyRef, ListVector](v) {
          def value(i: Int): AnyRef = getArray(vector, i, deserializer)(element.clsTag)
        }

      case (IterableEncoder(tag, element, _, _), v: ListVector) =>
        // The target collection class (`tag`) decides how the elements are assembled. The checks
        // are ordered: both ArraySeq flavours are tested before the general Iterable case.
        val deserializer = deserializerFor(element, v.getDataVector, timeZoneId)
        if (isSubClass(Classes.MUTABLE_ARRAY_SEQ, tag)) {
          // mutable ArraySeq is a bit special because we need to use an array of the element type.
          // Some parts of our codebase (unfortunately) rely on this for type inference on results.
          new VectorFieldDeserializer[mutable.ArraySeq[Any], ListVector](v) {
            def value(i: Int): mutable.ArraySeq[Any] = {
              val array = getArray(vector, i, deserializer)(element.clsTag)
              ScalaCollectionUtils.wrap(array)
            }
          }
        } else if (isSubClass(Classes.IMMUTABLE_ARRAY_SEQ, tag)) {
          new VectorFieldDeserializer[immutable.ArraySeq[Any], ListVector](v) {
            def value(i: Int): immutable.ArraySeq[Any] = {
              val array = getArray(vector, i, deserializer)(element.clsTag)
              array.asInstanceOf[Array[_]].toImmutableArraySeq
            }
          }
        } else if (isSubClass(Classes.ITERABLE, tag)) {
          // Any other Scala Iterable is built with a builder from the collection's companion.
          val companion = ScalaCollectionUtils.getIterableCompanion(tag)
          new VectorFieldDeserializer[Iterable[Any], ListVector](v) {
            def value(i: Int): Iterable[Any] = {
              val builder = companion.newBuilder[Any]
              loadListIntoBuilder(vector, i, deserializer, builder)
              builder.result()
            }
          }
        } else if (isSubClass(Classes.JLIST, tag)) {
          val newInstance = resolveJavaListCreator(tag)
          new VectorFieldDeserializer[JList[Any], ListVector](v) {
            def value(i: Int): JList[Any] = {
              // The elements of list `i` occupy [start, end) in the data vector.
              var index = v.getElementStartIndex(i)
              val end = v.getElementEndIndex(i)
              // NOTE(review): the element count is passed to the creator, presumably as an
              // initial capacity - confirm against resolveJavaListCreator.
              val list = newInstance(end - index)
              while (index < end) {
                list.add(deserializer.get(index))
                index += 1
              }
              list
            }
          }
        } else {
          throw unsupportedCollectionType(tag.runtimeClass)
        }

      case (MapEncoder(tag, key, value, _), v: MapVector) =>
        // Keys and values are read from the key/value children of the map's struct data vector.
        val structVector = v.getDataVector.asInstanceOf[StructVector]
        val keyDeserializer =
          deserializerFor(key, structVector.getChild(MapVector.KEY_NAME), timeZoneId)
        val valueDeserializer =
          deserializerFor(value, structVector.getChild(MapVector.VALUE_NAME), timeZoneId)
        if (isSubClass(Classes.MAP, tag)) {
          val companion = ScalaCollectionUtils.getMapCompanion(tag)
          new VectorFieldDeserializer[Map[Any, Any], MapVector](v) {
            def value(i: Int): Map[Any, Any] = {
              val builder = companion.newBuilder[Any, Any]
              // The entries of map `i` occupy [start, end) in the key and value vectors.
              var index = v.getElementStartIndex(i)
              val end = v.getElementEndIndex(i)
              builder.sizeHint(end - index)
              while (index < end) {
                builder += (keyDeserializer.get(index) -> valueDeserializer.get(index))
                index += 1
              }
              builder.result()
            }
          }
        } else if (isSubClass(Classes.JMAP, tag)) {
          val newInstance = resolveJavaMapCreator(tag)
          new VectorFieldDeserializer[JMap[Any, Any], MapVector](v) {
            def value(i: Int): JMap[Any, Any] = {
              val map = newInstance()
              var index = v.getElementStartIndex(i)
              val end = v.getElementEndIndex(i)
              while (index < end) {
                map.put(keyDeserializer.get(index), valueDeserializer.get(index))
                index += 1
              }
              map
            }
          }
        } else {
          throw unsupportedCollectionType(tag.runtimeClass)
        }

      case (ProductEncoder(tag, fields, outerPointerGetter), StructVectors(struct, vectors)) =>
        // Zero or one outer instance; when present it is the first constructor argument.
        val outer = outerPointerGetter.map(_()).toSeq
        // We should try to make this work with MethodHandles.
        // Fail with a message naming the class when no constructor matches, instead of the
        // bare MatchError a `val Some(...) = ...` pattern would raise on None.
        val constructor: Seq[AnyRef] => Any = ScalaReflection
          .findConstructor(
            tag.runtimeClass,
            outer.map(_.getClass) ++ fields.map(_.enc.clsTag.runtimeClass))
          .getOrElse {
            throw new RuntimeException(
              s"Unable to find a matching constructor for ${tag.runtimeClass}.")
          }
        val deserializers = if (isTuple(tag.runtimeClass)) {
          // Tuple fields are paired with the struct's vectors by position.
          fields.zip(vectors).map { case (field, vector) =>
            deserializerFor(field.enc, vector, timeZoneId)
          }
        } else {
          // The outer instance is exposed as a deserializer that returns it for every index, so
          // it can be passed to the constructor together with the field values.
          val outerDeserializer = outer.map { value =>
            new Deserializer[Any] {
              override def get(i: Int): Any = value
            }
          }
          // Non-tuple fields are paired with the struct's vectors by field name.
          val lookup = createFieldLookup(vectors)
          outerDeserializer ++ fields.map { field =>
            deserializerFor(field.enc, lookup(field.name), timeZoneId)
          }
        }
        new StructFieldSerializer[Any](struct) {
          def value(i: Int): Any = {
            constructor(deserializers.map(_.get(i).asInstanceOf[AnyRef]))
          }
        }

      case (r @ RowEncoder(fields), StructVectors(struct, vectors)) =>
        // Row fields are paired with the struct's vectors by field name.
        val lookup = createFieldLookup(vectors)
        val deserializers = fields.toArray.map { field =>
          deserializerFor(field.enc, lookup(field.name), timeZoneId)
        }
        new StructFieldSerializer[Any](struct) {
          def value(i: Int): Any = {
            val values = deserializers.map(_.get(i))
            new GenericRowWithSchema(values, r.schema)
          }
        }

      case (VariantEncoder, StructVectors(struct, vectors)) =>
        // The struct must have a "value" child and a "metadata" child whose Arrow field metadata
        // carries variant == "true".
        assert(vectors.exists(_.getName == "value"))
        assert(
          vectors.exists(field =>
            field.getName == "metadata" && field.getField.getMetadata
              .containsKey("variant") && field.getField.getMetadata.get("variant") == "true"))
        // Both children are decoded as binary.
        val valueDecoder =
          deserializerFor(
            BinaryEncoder,
            vectors
              .find(_.getName == "value")
              .getOrElse(throw CompilationErrors.columnNotFoundError("value")),
            timeZoneId)
        val metadataDecoder =
          deserializerFor(
            BinaryEncoder,
            vectors
              .find(_.getName == "metadata")
              .getOrElse(throw CompilationErrors.columnNotFoundError("metadata")),
            timeZoneId)
        new StructFieldSerializer[VariantVal](struct) {
          def value(i: Int): VariantVal = {
            new VariantVal(
              valueDecoder.get(i).asInstanceOf[Array[Byte]],
              metadataDecoder.get(i).asInstanceOf[Array[Byte]])
          }
        }

      case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) =>
        // Handle for the bean's no-argument constructor.
        val constructor =
          methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit]))
        val lookup = createFieldLookup(vectors)
        // One setter function per field that has a write method; fields without one are skipped.
        val setters = fields
          .filter(_.writeMethod.isDefined)
          .map { field =>
            val vector = lookup(field.name)
            val deserializer = deserializerFor(field.enc, vector, timeZoneId)
            val setter = methodLookup.findVirtual(
              tag.runtimeClass,
              field.writeMethod.get,
              MethodType.methodType(classOf[Unit], field.enc.clsTag.runtimeClass))
            (bean: Any, i: Int) => setter.invoke(bean, deserializer.get(i))
          }
        new StructFieldSerializer[Any](struct) {
          def value(i: Int): Any = {
            // Create a fresh bean and populate it through its setters.
            val instance = constructor.invoke()
            setters.foreach(_(instance, i))
            instance
          }
        }

      case (TransformingEncoder(_, encoder, provider, _), v) =>
        // `encoder` is the wrapped encoder bound by the pattern; it shadows the method parameter.
        // Values are read with its deserializer and then passed through the codec's `decode`.
        new Deserializer[Any] {
          private[this] val codec = provider()
          private[this] val deserializer = deserializerFor(encoder, v, timeZoneId)
          override def get(i: Int): Any = codec.decode(deserializer.get(i))
        }

      case (CalendarIntervalEncoder | _: UDTEncoder[_], _) =>
        // These encoders are rejected regardless of the vector they are paired with.
        throw ExecutionErrors.unsupportedDataTypeError(encoder.dataType)

      case _ =>
        // `data` can be null here because the type patterns above never match null. Build the
        // message null-safely so the actual problem is reported instead of a
        // NullPointerException.
        val vectorClass = Option(data).map(_.getClass.toString).getOrElse("null")
        throw new RuntimeException(
          s"Unsupported Encoder($encoder)/Vector($vectorClass) combination.")
    }
  }