def apply()

in sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala [313:419]


    /**
     * Inserts implicit casts into the children of a fixed set of expressions (array/map
     * constructors, array and map concatenation, array_join, sequence, coalesce,
     * greatest/least, NaNvl, timestamp sum/average and RandStr's length) so that their
     * argument types line up.
     *
     * A matched expression is rebuilt with casts applied to the relevant children. It is
     * returned unchanged when no common/target type can be found. Expressions that match
     * none of the cases are also returned unchanged.
     *
     * @param expression the expression whose direct children may need casting
     * @return `expression` itself, or a rebuilt copy with casts applied to its children
     */
    def apply(expression: Expression): Expression = expression match {
      case a @ CreateArray(children, _) if !haveSameType(children.map(_.dataType)) =>
        findWiderCommonType(children.map(_.dataType))
          .map(finalDataType => a.copy(children.map(castIfNotSameType(_, finalDataType))))
          .getOrElse(a)

      case c @ Concat(children)
          if children.forall(child => ArrayType.acceptsType(child.dataType)) &&
          !haveSameType(c.inputTypesForMerging) =>
        findWiderCommonType(children.map(_.dataType))
          .map(finalDataType => Concat(children.map(castIfNotSameType(_, finalDataType))))
          .getOrElse(c)

      case aj @ ArrayJoin(arr, d, nr)
          if !AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true)).
            acceptsType(arr.dataType) &&
          ArrayType.acceptsType(arr.dataType) =>
        // The guard has already checked ArrayType.acceptsType(arr.dataType), so the cast
        // below only reads the element-nullability flag of an array type.
        val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull
        implicitCast(arr, ArrayType(StringType, containsNull))
          .map(castedArr => ArrayJoin(castedArr, d, nr))
          .getOrElse(aj)

      case s: Sequence if !haveSameType(s.coercibleChildren.map(_.dataType)) =>
        findWiderCommonType(s.coercibleChildren.map(_.dataType))
          .map(widerDataType => s.castChildrenTo(widerDataType))
          .getOrElse(s)

      case m @ MapConcat(children)
          if children.forall(child => MapType.acceptsType(child.dataType)) &&
          !haveSameType(m.inputTypesForMerging) =>
        findWiderCommonType(children.map(_.dataType))
          .map(finalDataType => MapConcat(children.map(castIfNotSameType(_, finalDataType))))
          .getOrElse(m)

      case m: CreateMap
          if m.keys.length == m.values.length &&
          (!haveSameType(m.keys.map(_.dataType)) || !haveSameType(m.values.map(_.dataType))) =>
        // Keys and values are widened independently of each other.
        val newKeys = findWiderCommonType(m.keys.map(_.dataType))
          .map(keyType => m.keys.map(castIfNotSameType(_, keyType)))
          .getOrElse(m.keys)
        val newValues = findWiderCommonType(m.values.map(_.dataType))
          .map(valueType => m.values.map(castIfNotSameType(_, valueType)))
          .getOrElse(m.values)
        // Re-interleave into CreateMap's flat (key, value, key, value, ...) children layout.
        m.copy(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })

      // Hive lets you do aggregation of timestamps... for some reason
      case Sum(e @ TimestampTypeExpression(), _) => Sum(Cast(e, DoubleType))
      case Average(e @ TimestampTypeExpression(), _) => Average(Cast(e, DoubleType))

      // Coalesce should return the first non-null value, which could be any column
      // from the list. So we need to make sure the return type is deterministic and
      // compatible with every child column.
      case c @ Coalesce(es) if !haveSameType(c.inputTypesForMerging) =>
        findWiderCommonType(es.map(_.dataType))
          .map(finalDataType => Coalesce(es.map(castIfNotSameType(_, finalDataType))))
          .getOrElse(c)

      // When finding wider type for `Greatest` and `Least`, we should handle decimal types even if
      // we need to truncate, but we should not promote one side to string if the other side is
      // string.
      case g @ Greatest(children) if !haveSameType(g.inputTypesForMerging) =>
        findWiderTypeWithoutStringPromotion(children.map(_.dataType))
          .map(finalDataType => Greatest(children.map(castIfNotSameType(_, finalDataType))))
          .getOrElse(g)

      case l @ Least(children) if !haveSameType(l.inputTypesForMerging) =>
        findWiderTypeWithoutStringPromotion(children.map(_.dataType))
          .map(finalDataType => Least(children.map(castIfNotSameType(_, finalDataType))))
          .getOrElse(l)

      // A Float/Double mix is widened to Double; a NULL-typed right side takes the left's type.
      case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType =>
        NaNvl(l, Cast(r, DoubleType))
      case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>
        NaNvl(Cast(l, DoubleType), r)
      case NaNvl(l, r) if r.dataType == NullType => NaNvl(l, Cast(r, l.dataType))

      // A non-integer length is cast to IntegerType when an implicit cast exists.
      case r: RandStr if r.length.dataType != IntegerType =>
        implicitCast(r.length, IntegerType).map { casted =>
          r.copy(length = casted)
        }.getOrElse(r)

      case other => other
    }