private def generateExpression()

in sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala [77:340]


  /**
   * Recursively translates a Catalyst `Expression` into its Data Source V2 counterpart.
   *
   * Composite cases return `None` unless every child they need can be translated, so callers
   * never receive a partially translated tree.
   *
   * @param expr        the Catalyst expression to translate.
   * @param isPredicate whether the caller needs a boolean predicate rather than a plain value.
   *                    When true, boolean-typed column references and boolean CASE WHEN
   *                    expressions are emitted as `V2Predicate`s. The flag is also forwarded to
   *                    `generateExpressionWithName` / `generateExpressionWithNameByChildren`.
   * @return the translated V2 expression, or `None` if `expr` cannot be translated.
   */
  private def generateExpression(
      expr: Expression, isPredicate: Boolean = false): Option[V2Expression] = expr match {
    case Literal(true, BooleanType) => Some(new AlwaysTrue())
    case Literal(false, BooleanType) => Some(new AlwaysFalse())
    case Literal(value, dataType) => Some(LiteralValue(value, dataType))
    case col @ ColumnOrField(nameParts) =>
      val ref = FieldReference(nameParts)
      // A bare boolean column is not a predicate by itself. In predicate position it is
      // expressed as `col = true` (NULL stays NULL under SQL three-valued logic).
      if (isPredicate && col.dataType.isInstanceOf[BooleanType]) {
        Some(new V2Predicate("=", Array(ref, LiteralValue(true, BooleanType))))
      } else {
        Some(ref)
      }
    case InSet(child, hset) =>
      generateExpression(child).map { v =>
        val children =
          (v +: hset.toSeq.map(elem => LiteralValue(elem, child.dataType))).toArray[V2Expression]
        new V2Predicate("IN", children)
      }
    // The optimizer only converts In to InSet when the list exceeds a certain size, so an In
    // expression may still reach this point and needs to be pushed down as well.
    case In(value, list) =>
      val translatedValue = generateExpression(value)
      val listExpressions = list.flatMap(generateExpression(_))
      // Push down only if every list element translated; a shorter IN list would change the
      // meaning of the predicate.
      if (listExpressions.length == list.length) {
        // The children look like [expr, value1, ..., valueN]
        translatedValue.map { v =>
          new V2Predicate("IN", (v +: listExpressions).toArray[V2Expression])
        }
      } else {
        None
      }
    case IsNull(col) => generateExpression(col)
      .map(c => new V2Predicate("IS_NULL", Array[V2Expression](c)))
    case IsNotNull(col) => generateExpression(col)
      .map(c => new V2Predicate("IS_NOT_NULL", Array[V2Expression](c)))
    case p: StringPredicate =>
      val left = generateExpression(p.left)
      val right = generateExpression(p.right)
      for (l <- left; r <- right) yield {
        val name = p match {
          case _: StartsWith => "STARTS_WITH"
          case _: EndsWith => "ENDS_WITH"
          case _: Contains => "CONTAINS"
        }
        new V2Predicate(name, Array[V2Expression](l, r))
      }
    case Cast(child, dataType, _, evalMode)
        if evalMode == EvalMode.ANSI || Cast.canUpCast(child.dataType, dataType) =>
      generateExpression(child).map(v => new V2Cast(v, child.dataType, dataType))
    case AggregateExpression(aggregateFunction, Complete, isDistinct, None, _) =>
      generateAggregateFunc(aggregateFunction, isDistinct)
    case Abs(_, true) => generateExpressionWithName("ABS", expr, isPredicate)
    case _: Coalesce => generateExpressionWithName("COALESCE", expr, isPredicate)
    case _: Greatest => generateExpressionWithName("GREATEST", expr, isPredicate)
    case _: Least => generateExpressionWithName("LEAST", expr, isPredicate)
    case Rand(_, hideSeed) =>
      if (hideSeed) {
        Some(new GeneralScalarExpression("RAND", Array.empty[V2Expression]))
      } else {
        generateExpressionWithName("RAND", expr, isPredicate)
      }
    case _: Logarithm => generateExpressionWithName("LOG", expr, isPredicate)
    case _: Log10 => generateExpressionWithName("LOG10", expr, isPredicate)
    case _: Log2 => generateExpressionWithName("LOG2", expr, isPredicate)
    case _: Log => generateExpressionWithName("LN", expr, isPredicate)
    case _: Exp => generateExpressionWithName("EXP", expr, isPredicate)
    case _: Pow => generateExpressionWithName("POWER", expr, isPredicate)
    case _: Sqrt => generateExpressionWithName("SQRT", expr, isPredicate)
    case _: Floor => generateExpressionWithName("FLOOR", expr, isPredicate)
    case _: Ceil => generateExpressionWithName("CEIL", expr, isPredicate)
    case _: Round => generateExpressionWithName("ROUND", expr, isPredicate)
    case _: Sin => generateExpressionWithName("SIN", expr, isPredicate)
    case _: Sinh => generateExpressionWithName("SINH", expr, isPredicate)
    case _: Cos => generateExpressionWithName("COS", expr, isPredicate)
    case _: Cosh => generateExpressionWithName("COSH", expr, isPredicate)
    case _: Tan => generateExpressionWithName("TAN", expr, isPredicate)
    case _: Tanh => generateExpressionWithName("TANH", expr, isPredicate)
    case _: Cot => generateExpressionWithName("COT", expr, isPredicate)
    case _: Asin => generateExpressionWithName("ASIN", expr, isPredicate)
    case _: Asinh => generateExpressionWithName("ASINH", expr, isPredicate)
    case _: Acos => generateExpressionWithName("ACOS", expr, isPredicate)
    case _: Acosh => generateExpressionWithName("ACOSH", expr, isPredicate)
    case _: Atan => generateExpressionWithName("ATAN", expr, isPredicate)
    case _: Atanh => generateExpressionWithName("ATANH", expr, isPredicate)
    case _: Atan2 => generateExpressionWithName("ATAN2", expr, isPredicate)
    case _: Cbrt => generateExpressionWithName("CBRT", expr, isPredicate)
    case _: ToDegrees => generateExpressionWithName("DEGREES", expr, isPredicate)
    case _: ToRadians => generateExpressionWithName("RADIANS", expr, isPredicate)
    case _: Signum => generateExpressionWithName("SIGN", expr, isPredicate)
    case _: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", expr, isPredicate)
    case and: And =>
      // AND expects predicates on both sides. Both operands are translated eagerly.
      val l = generatePredicateOperand(and.left)
      val r = generatePredicateOperand(and.right)
      for (lp <- l; rp <- r) yield new V2And(lp, rp)
    case or: Or =>
      // OR expects predicates on both sides. Both operands are translated eagerly.
      val l = generatePredicateOperand(or.left)
      val r = generatePredicateOperand(or.right)
      for (lp <- l; rp <- r) yield new V2Or(lp, rp)
    case b: BinaryOperator if canTranslate(b) =>
      val l = generateExpression(b.left)
      val r = generateExpression(b.right)
      for (lv <- l; rv <- r) yield b match {
        // Normalize `literal op column` into `column op literal` by swapping the operands and
        // mirroring the comparison operator.
        case _: BinaryComparison if lv.isInstanceOf[LiteralValue[_]] &&
            rv.isInstanceOf[FieldReference] =>
          new V2Predicate(flipComparisonOperatorName(b.sqlOperator),
            Array[V2Expression](rv, lv))
        case _: Predicate =>
          new V2Predicate(b.sqlOperator, Array[V2Expression](lv, rv))
        case _ =>
          new GeneralScalarExpression(b.sqlOperator, Array[V2Expression](lv, rv))
      }
    case Not(eq: EqualTo) =>
      // NOT (a = b) is emitted directly as a `<>` predicate.
      val left = generateExpression(eq.left)
      val right = generateExpression(eq.right)
      for (l <- left; r <- right) yield new V2Predicate("<>", Array[V2Expression](l, r))
    case Not(child) =>
      // NOT expects a predicate operand.
      generatePredicateOperand(child).map(p => new V2Not(p))
    case UnaryMinus(_, true) => generateExpressionWithName("-", expr, isPredicate)
    case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate)
    case caseWhen @ CaseWhen(branches, elseValue) =>
      val conditions = branches.map(_._1).flatMap(generateExpression(_, true))
      val values = branches.map(_._2).flatMap(generateExpression(_, isPredicate))
      val elseExprOpt = elseValue.flatMap(generateExpression(_, isPredicate))
      if (conditions.length == branches.length && values.length == branches.length &&
          elseExprOpt.size == elseValue.size) {
        val branchExpressions = conditions.zip(values).flatMap { case (c, v) =>
          Seq[V2Expression](c, v)
        }
        val children = (branchExpressions ++ elseExprOpt).toArray[V2Expression]
        // The children look like [condition1, value1, ..., conditionN, valueN (, elseValue)]
        if (isPredicate && caseWhen.dataType.isInstanceOf[BooleanType]) {
          Some(new V2Predicate("CASE_WHEN", children))
        } else {
          Some(new GeneralScalarExpression("CASE_WHEN", children))
        }
      } else {
        None
      }
    case _: If => generateExpressionWithName("CASE_WHEN", expr, isPredicate)
    case substring: Substring =>
      // A length of Integer.MAX_VALUE means "to the end of the string", so it is omitted.
      val children = if (substring.len == Literal(Integer.MAX_VALUE)) {
        Seq(substring.str, substring.pos)
      } else {
        substring.children
      }
      generateExpressionWithNameByChildren("SUBSTRING", children, substring.dataType, isPredicate)
    case _: Upper => generateExpressionWithName("UPPER", expr, isPredicate)
    case _: Lower => generateExpressionWithName("LOWER", expr, isPredicate)
    case BitLength(child) if child.dataType.isInstanceOf[StringType] =>
      generateExpressionWithName("BIT_LENGTH", expr, isPredicate)
    case Length(child) if child.dataType.isInstanceOf[StringType] =>
      generateExpressionWithName("CHAR_LENGTH", expr, isPredicate)
    case _: Concat => generateExpressionWithName("CONCAT", expr, isPredicate)
    case _: StringTranslate => generateExpressionWithName("TRANSLATE", expr, isPredicate)
    case _: StringTrim => generateExpressionWithName("TRIM", expr, isPredicate)
    case _: StringTrimLeft => generateExpressionWithName("LTRIM", expr, isPredicate)
    case _: StringTrimRight => generateExpressionWithName("RTRIM", expr, isPredicate)
    case overlay: Overlay =>
      // A length of -1 is dropped so that only the explicitly given arguments are emitted.
      val children = if (overlay.len == Literal(-1)) {
        Seq(overlay.input, overlay.replace, overlay.pos)
      } else {
        overlay.children
      }
      generateExpressionWithNameByChildren("OVERLAY", children, overlay.dataType, isPredicate)
    case _: DateAdd => generateExpressionWithName("DATE_ADD", expr, isPredicate)
    case _: DateDiff => generateExpressionWithName("DATE_DIFF", expr, isPredicate)
    case _: TruncDate => generateExpressionWithName("TRUNC", expr, isPredicate)
    case Second(child, _) =>
      generateExpression(child).map(v => new V2Extract("SECOND", v))
    case Minute(child, _) =>
      generateExpression(child).map(v => new V2Extract("MINUTE", v))
    case Hour(child, _) =>
      generateExpression(child).map(v => new V2Extract("HOUR", v))
    case Month(child) =>
      generateExpression(child).map(v => new V2Extract("MONTH", v))
    case Quarter(child) =>
      generateExpression(child).map(v => new V2Extract("QUARTER", v))
    case Year(child) =>
      generateExpression(child).map(v => new V2Extract("YEAR", v))
    // DayOfWeek uses Sunday = 1, Monday = 2, ... and ISO standard is Monday = 1, ...,
    // so we use the formula ((ISO_standard % 7) + 1) to do translation.
    case DayOfWeek(child) =>
      generateExpression(child).map(v => new GeneralScalarExpression("+",
        Array[V2Expression](new GeneralScalarExpression("%",
          Array[V2Expression](new V2Extract("DAY_OF_WEEK", v), LiteralValue(7, IntegerType))),
          LiteralValue(1, IntegerType))))
    // WeekDay uses Monday = 0, Tuesday = 1, ... and ISO standard is Monday = 1, ...,
    // so we use the formula (ISO_standard - 1) to do translation.
    case WeekDay(child) =>
      generateExpression(child).map(v => new GeneralScalarExpression("-",
        Array[V2Expression](new V2Extract("DAY_OF_WEEK", v), LiteralValue(1, IntegerType))))
    case DayOfMonth(child) =>
      generateExpression(child).map(v => new V2Extract("DAY", v))
    case DayOfYear(child) =>
      generateExpression(child).map(v => new V2Extract("DAY_OF_YEAR", v))
    case WeekOfYear(child) =>
      generateExpression(child).map(v => new V2Extract("WEEK", v))
    case YearOfWeek(child) =>
      generateExpression(child).map(v => new V2Extract("YEAR_OF_WEEK", v))
    case _: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", expr, isPredicate)
    case _: AesDecrypt => generateExpressionWithName("AES_DECRYPT", expr, isPredicate)
    case _: Crc32 => generateExpressionWithName("CRC32", expr, isPredicate)
    case _: Md5 => generateExpressionWithName("MD5", expr, isPredicate)
    case _: Sha1 => generateExpressionWithName("SHA1", expr, isPredicate)
    case _: Sha2 => generateExpressionWithName("SHA2", expr, isPredicate)
    case _: StringLPad => generateExpressionWithName("LPAD", expr, isPredicate)
    case _: StringRPad => generateExpressionWithName("RPAD", expr, isPredicate)
    // TODO: support other expressions
    case ApplyFunctionExpression(function, children) =>
      val childrenExpressions = children.flatMap(generateExpression(_))
      if (childrenExpressions.length == children.length) {
        Some(new UserDefinedScalarFunc(
          function.name(), function.canonicalName(), childrenExpressions.toArray[V2Expression]))
      } else {
        None
      }
    case Invoke(Literal(obj, _), functionName, _, arguments, _, _, _, _) =>
      obj match {
        case function: ScalarFunction[_] if ScalarFunction.MAGIC_METHOD_NAME == functionName =>
          val argumentExpressions = arguments.flatMap(generateExpression(_))
          if (argumentExpressions.length == arguments.length) {
            Some(new UserDefinedScalarFunc(
              function.name(), function.canonicalName(), argumentExpressions.toArray[V2Expression]))
          } else {
            None
          }
        case _ =>
          None
      }
    case StaticInvoke(_, _, _, arguments, _, _, _, _, Some(scalarFunc)) =>
      val argumentExpressions = arguments.flatMap(generateExpression(_))
      if (argumentExpressions.length == arguments.length) {
        Some(new UserDefinedScalarFunc(
          scalarFunc.name(), scalarFunc.canonicalName(), argumentExpressions.toArray[V2Expression]))
      } else {
        None
      }
    case _ => None
  }

  /**
   * Translates a boolean operand of AND, OR or NOT, all of which require a `V2Predicate`.
   *
   * Some boolean-typed expressions translate to a non-predicate V2 expression even when
   * `isPredicate` is true. Examples are a cast to boolean (`V2Cast`), a boolean-returning
   * function call (`UserDefinedScalarFunc`) and a NULL boolean literal (`LiteralValue`).
   * Previously such operands tripped an assertion. Now they are wrapped as
   * `operand = true`, the same encoding `generateExpression` uses for boolean columns, which
   * keeps SQL three-valued-logic semantics (NULL stays NULL).
   *
   * @param expr a boolean-typed Catalyst expression.
   * @return the operand as a `V2Predicate`, or `None` if it cannot be translated.
   */
  private def generatePredicateOperand(expr: Expression): Option[V2Predicate] =
    generateExpression(expr, isPredicate = true).map {
      case predicate: V2Predicate => predicate
      case other =>
        new V2Predicate("=", Array[V2Expression](other, LiteralValue(true, BooleanType)))
    }