in sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala [258:417]
/**
 * Recursively derives the [[AgnosticEncoder]] for a reflected Scala type.
 *
 * The type is first normalised with `baseType` and then tested against an ordered list of
 * cases; the first matching case wins, so the order of the cases is significant.
 *
 * @param tpe
 *   the type to build an encoder for.
 * @param seenTypeSet
 *   constructor-defined types already being expanded on the current recursion path; meeting one
 *   of them again is reported as a circular reference.
 * @param path
 *   the walk performed so far; extended on every recursive step and handed to the error
 *   builders.
 * @param isRowEncoderSupported
 *   whether a [[Row]] type may be mapped to `UnboundRowEncoder`. When `false`, the `Row` case is
 *   skipped and the type is tested against the remaining cases, ending in the "cannot find
 *   encoder" error if none applies. The flag is forwarded unchanged to every nested type.
 * @return
 *   the encoder selected for `tpe`.
 */
private def encoderFor(
    tpe: `Type`,
    seenTypeSet: Set[`Type`],
    path: WalkedTypePath,
    isRowEncoderSupported: Boolean): AgnosticEncoder[_] = {

  // Builds the encoder for a single-type-argument collection type (used for Seq and Set below).
  // The target class is the collection's own runtime class when its companion exposes a
  // `newBuilder` member, and `fallbackClass` otherwise.
  def createIterableEncoder(t: `Type`, fallbackClass: Class[_]): AgnosticEncoder[_] = {
    val TypeRef(_, _, Seq(elementType)) = t
    val encoder = encoderFor(
      elementType,
      seenTypeSet,
      path.recordArray(getClassNameFromType(elementType)),
      isRowEncoderSupported)
    val companion = t.dealias.typeSymbol.companion.typeSignature
    val targetClass = companion.member(TermName("newBuilder")) match {
      case NoSymbol => fallbackClass
      case _ => mirror.runtimeClass(t.typeSymbol.asClass)
    }
    IterableEncoder(
      ClassTag(targetClass),
      encoder,
      encoder.nullable,
      lenientSerialization = false)
  }

  baseType(tpe) match {
    // This must stay the first case: scala.Null conforms to every reference type, so a Null
    // type would otherwise be captured by the first reference-type case further down.
    case t if isSubtype(t, definitions.NullTpe) => NullEncoder

    // Primitive encoders
    case t if isSubtype(t, definitions.BooleanTpe) => PrimitiveBooleanEncoder
    case t if isSubtype(t, definitions.ByteTpe) => PrimitiveByteEncoder
    case t if isSubtype(t, definitions.ShortTpe) => PrimitiveShortEncoder
    case t if isSubtype(t, definitions.IntTpe) => PrimitiveIntEncoder
    case t if isSubtype(t, definitions.LongTpe) => PrimitiveLongEncoder
    case t if isSubtype(t, definitions.FloatTpe) => PrimitiveFloatEncoder
    case t if isSubtype(t, definitions.DoubleTpe) => PrimitiveDoubleEncoder
    case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => BoxedBooleanEncoder
    case t if isSubtype(t, localTypeOf[java.lang.Byte]) => BoxedByteEncoder
    case t if isSubtype(t, localTypeOf[java.lang.Short]) => BoxedShortEncoder
    case t if isSubtype(t, localTypeOf[java.lang.Integer]) => BoxedIntEncoder
    case t if isSubtype(t, localTypeOf[java.lang.Long]) => BoxedLongEncoder
    case t if isSubtype(t, localTypeOf[java.lang.Float]) => BoxedFloatEncoder
    case t if isSubtype(t, localTypeOf[java.lang.Double]) => BoxedDoubleEncoder
    // Array[Byte] is handled here, ahead of the generic Array[_] case below.
    case t if isSubtype(t, localTypeOf[Array[Byte]]) => BinaryEncoder

    // Enums
    case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
      JavaEnumEncoder(ClassTag(getClassFromType(t)))
    case t if isSubtype(t, localTypeOf[Enumeration#Value]) =>
      // package example
      // object Foo extends Enumeration {
      //  type Foo = Value
      //  val E1, E2 = Value
      // }
      // the fullName of tpe is example.Foo.Foo, but we need example.Foo so that
      // we can call example.Foo.withName to deserialize string to enumeration.
      val parent = getClassFromType(t.asInstanceOf[TypeRef].pre)
      ScalaEnumEncoder(parent, ClassTag(getClassFromType(t)))

    // Leaf encoders
    case t if isSubtype(t, localTypeOf[String]) => StringEncoder
    case t if isSubtype(t, localTypeOf[Decimal]) => DEFAULT_SPARK_DECIMAL_ENCODER
    case t if isSubtype(t, localTypeOf[BigDecimal]) => DEFAULT_SCALA_DECIMAL_ENCODER
    case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => DEFAULT_JAVA_DECIMAL_ENCODER
    case t if isSubtype(t, localTypeOf[BigInt]) => ScalaBigIntEncoder
    case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => JavaBigIntEncoder
    case t if isSubtype(t, localTypeOf[CalendarInterval]) => CalendarIntervalEncoder
    case t if isSubtype(t, localTypeOf[java.time.Duration]) => DayTimeIntervalEncoder
    case t if isSubtype(t, localTypeOf[java.time.Period]) => YearMonthIntervalEncoder
    case t if isSubtype(t, localTypeOf[java.sql.Date]) => STRICT_DATE_ENCODER
    case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => STRICT_LOCAL_DATE_ENCODER
    case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => STRICT_TIMESTAMP_ENCODER
    case t if isSubtype(t, localTypeOf[java.time.Instant]) => STRICT_INSTANT_ENCODER
    case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => LocalDateTimeEncoder
    case t if isSubtype(t, localTypeOf[java.time.LocalTime]) => LocalTimeEncoder
    case t if isSubtype(t, localTypeOf[VariantVal]) => VariantEncoder
    // Row is only encodable when the caller opted in. Without the flag the type is tested
    // against the remaining cases instead of silently receiving an unbound row encoder.
    case t if isRowEncoderSupported && isSubtype(t, localTypeOf[Row]) => UnboundRowEncoder

    // UDT encoders
    case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
      // Instantiate the UDT named by the annotation through its no-arg constructor.
      val udt = getClassFromType(t)
        .getAnnotation(classOf[SQLUserDefinedType])
        .udt()
        .getConstructor()
        .newInstance()
        .asInstanceOf[UserDefinedType[Any]]
      // The UDT class is re-read from the annotation on the UDT's declared user class.
      val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()
      UDTEncoder(udt, udtClass)
    case t if UDTRegistration.exists(getClassNameFromType(t)) =>
      // `.get` is guarded by the `exists` check in the case condition.
      val udt = UDTRegistration
        .getUDTFor(getClassNameFromType(t))
        .get
        .getConstructor()
        .newInstance()
        .asInstanceOf[UserDefinedType[Any]]
      UDTEncoder(udt, udt.getClass)

    // Complex encoders
    case t if isSubtype(t, localTypeOf[Option[_]]) =>
      val TypeRef(_, _, Seq(optType)) = t
      val encoder = encoderFor(
        optType,
        seenTypeSet,
        path.recordOption(getClassNameFromType(optType)),
        isRowEncoderSupported)
      OptionEncoder(encoder)
    case t if isSubtype(t, localTypeOf[Array[_]]) =>
      val TypeRef(_, _, Seq(elementType)) = t
      val encoder = encoderFor(
        elementType,
        seenTypeSet,
        path.recordArray(getClassNameFromType(elementType)),
        isRowEncoderSupported)
      ArrayEncoder(encoder, encoder.nullable)
    case t if isSubtype(t, localTypeOf[scala.collection.Seq[_]]) =>
      createIterableEncoder(t, classOf[scala.collection.Seq[_]])
    case t if isSubtype(t, localTypeOf[scala.collection.Set[_]]) =>
      createIterableEncoder(t, classOf[scala.collection.Set[_]])
    case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
      val TypeRef(_, _, Seq(keyType, valueType)) = t
      val keyEncoder = encoderFor(
        keyType,
        seenTypeSet,
        path.recordKeyForMap(getClassNameFromType(keyType)),
        isRowEncoderSupported)
      val valueEncoder = encoderFor(
        valueType,
        seenTypeSet,
        path.recordValueForMap(getClassNameFromType(valueType)),
        isRowEncoderSupported)
      MapEncoder(ClassTag(getClassFromType(t)), keyEncoder, valueEncoder, valueEncoder.nullable)

    case t if definedByConstructorParams(t) =>
      // A type that is already being expanded higher up the recursion refers to itself.
      if (seenTypeSet.contains(t)) {
        throw ExecutionErrors.cannotHaveCircularReferencesInClassError(t.toString)
      }
      val params = getConstructorParameters(t).map { case (fieldName, fieldType) =>
        // Reject Java keywords and names whose encoded form is not a valid Java identifier.
        if (SourceVersion.isKeyword(fieldName) ||
          !SourceVersion.isIdentifier(encodeFieldNameToIdentifier(fieldName))) {
          throw ExecutionErrors.cannotUseInvalidJavaIdentifierAsFieldNameError(fieldName, path)
        }
        val encoder = encoderFor(
          fieldType,
          seenTypeSet + t,
          path.recordField(getClassNameFromType(fieldType), fieldName),
          isRowEncoderSupported)
        EncoderField(fieldName, encoder, encoder.nullable, Metadata.empty)
      }
      val cls = getClassFromType(t)
      // Option(...) turns a null outer scope into None.
      ProductEncoder(ClassTag(cls), params, Option(OuterScopes.getOuterScope(cls)))

    case _ =>
      // The error names the caller-supplied type, not the normalised one.
      throw ExecutionErrors.cannotFindEncoderForTypeError(tpe.toString)
  }
}