in sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala [313:419]
def apply(expression: Expression): Expression = expression match {
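// Promote the elements of `array(...)` to the widest common type when they differ.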
case a @ CreateArray(children, _) if !haveSameType(children.map(_.dataType)) =>
val types = children.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => a.copy(children.map(castIfNotSameType(_, finalDataType)))
case None => a
}
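// `concat` of arrays: widen every input array to a common array type before merging.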
case c @ Concat(children)
if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
!haveSameType(c.inputTypesForMerging) =>
val types = children.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => Concat(children.map(castIfNotSameType(_, finalDataType)))
case None => c
}
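// `array_join` expects an array of strings; if the elements are not already string-typed
// (under any collation), cast the array to an array of strings when an implicit cast exists.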
case aj @ ArrayJoin(arr, d, nr)
if !AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true)).
acceptsType(arr.dataType) &&
ArrayType.acceptsType(arr.dataType) =>
val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull
implicitCast(arr, ArrayType(StringType, containsNull)) match {
case Some(castedArr) => ArrayJoin(castedArr, d, nr)
case None => aj
}
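// Widen the coercible arguments of `sequence` to a single common type.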
case s @ Sequence(_, _, _, timeZoneId)
if !haveSameType(s.coercibleChildren.map(_.dataType)) =>
val types = s.coercibleChildren.map(_.dataType)
findWiderCommonType(types) match {
case Some(widerDataType) => s.castChildrenTo(widerDataType)
case None => s
}
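// `map_concat` requires all input maps to share the same key and value types; widen them.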
case m @ MapConcat(children)
if children.forall(c => MapType.acceptsType(c.dataType)) &&
!haveSameType(m.inputTypesForMerging) =>
val types = children.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => MapConcat(children.map(castIfNotSameType(_, finalDataType)))
case None => m
}
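// For `map(...)`, widen the key expressions and the value expressions independently.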
case m @ CreateMap(children, _)
if m.keys.length == m.values.length &&
(!haveSameType(m.keys.map(_.dataType)) || !haveSameType(m.values.map(_.dataType))) =>
val keyTypes = m.keys.map(_.dataType)
val newKeys = findWiderCommonType(keyTypes) match {
case Some(finalDataType) => m.keys.map(castIfNotSameType(_, finalDataType))
case None => m.keys
}
val valueTypes = m.values.map(_.dataType)
val newValues = findWiderCommonType(valueTypes) match {
case Some(finalDataType) => m.values.map(castIfNotSameType(_, finalDataType))
case None => m.values
}
m.copy(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
// Hive lets you do aggregation of timestamps... for some reason
case Sum(e @ TimestampTypeExpression(), _) => Sum(Cast(e, DoubleType))
case Average(e @ TimestampTypeExpression(), _) => Average(Cast(e, DoubleType))
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
case c @ Coalesce(es) if !haveSameType(c.inputTypesForMerging) =>
val types = es.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) =>
Coalesce(es.map(castIfNotSameType(_, finalDataType)))
case None =>
c
}
// When finding wider type for `Greatest` and `Least`, we should handle decimal types even if
// we need to truncate, but we should not promote one side to string if the other side is
// string.
case g @ Greatest(children) if !haveSameType(g.inputTypesForMerging) =>
val types = children.map(_.dataType)
findWiderTypeWithoutStringPromotion(types) match {
case Some(finalDataType) => Greatest(children.map(castIfNotSameType(_, finalDataType)))
case None => g
}
case l @ Least(children) if !haveSameType(l.inputTypesForMerging) =>
val types = children.map(_.dataType)
findWiderTypeWithoutStringPromotion(types) match {
case Some(finalDataType) => Least(children.map(castIfNotSameType(_, finalDataType)))
case None => l
}
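// `nanvl` requires matching argument types: widen float to double, and cast a null-typed
// right argument to the left argument's type.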
case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType =>
NaNvl(l, Cast(r, DoubleType))
case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>
NaNvl(Cast(l, DoubleType), r)
case NaNvl(l, r) if r.dataType == NullType => NaNvl(l, Cast(r, l.dataType))
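// RandStr's length argument must be an integer; apply an implicit cast when one exists.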
case r: RandStr if r.length.dataType != IntegerType =>
implicitCast(r.length, IntegerType).map { casted =>
r.copy(length = casted)
}.getOrElse(r)
case other => other
}