def checkIfCastCanBeRemove()

in integration/spark/src/main/scala/org/apache/spark/sql/execution/CastExpressionOptimization.scala [158:260]


  /**
   * Tries to rewrite a predicate of the shape `cast(column) <op> literal` (or the mirrored
   * `literal <op> cast(column)`), `cast(column) IN (...)` or `NOT (cast(column) IN (...))`
   * into a filter expression on the un-cast column.
   *
   * Supported operators are `=`, `NOT (=)`, `>`, `<`, `>=`, `<=`, `IN` and `NOT IN`. A specialised
   * rewrite is attempted only for these (column type, literal type) pairs:
   *  - Date / Timestamp column compared with a String literal
   *  - Integer column compared with a Double literal
   *  - Short column compared with an Integer literal
   *  - for plain equality only: an Array column, which is re-checked against its element type
   * Every other type pair falls back to `transformExpression(expr)`.
   *
   * NOTE: the final `match` has no default case, so `expr` must be one of the shapes listed
   * above; any other expression raises a `scala.MatchError`.
   *
   * @param expr the Spark predicate to inspect
   * @return the result of the specialised rewrite helper, or `Some(transformExpression(expr))`
   *         when no specialised rewrite applies
   */
  def checkIfCastCanBeRemove(expr: SparkExpression): Option[Expression] = {

    // Handles the binary comparison shapes.
    // `nonEqual` is false only for a plain `=`; every other comparison (including a negated
    // equality) additionally passes through updateFilterForNonEqualTimeStamp for date/timestamp
    // columns.
    def checkBinaryExpression(attributeType: SparkDataType,
        value: Any,
        valueType: SparkDataType,
        nonEqual: Boolean = true): Option[Expression] = {
      attributeType match {
        case _: DateType | _: TimestampType if valueType.sameType(StringType) =>
          val filter = updateFilterForTimeStamp(value, expr, attributeType)
          if (nonEqual) {
            updateFilterForNonEqualTimeStamp(value, expr, filter)
          } else {
            filter
          }
        case _: IntegerType if valueType.sameType(DoubleType) =>
          updateFilterForInt(value, expr, attributeType)
        case _: ShortType if valueType.sameType(IntegerType) =>
          updateFilterForShort(value, expr, attributeType)
        // Only plain equality on an array column is re-evaluated against the element type.
        case arr: ArrayType if !nonEqual =>
          checkBinaryExpression(arr.elementType, value, valueType, nonEqual)
        case _ => Some(transformExpression(expr))
      }
    }

    // Builds the IN / NOT IN filter from the evaluated literals of `list`, but only when the
    // type-cast list (`newList`) differs from the original one; otherwise the expression is
    // transformed as-is.
    def checkInValueList(attributeName: String,
        list: Seq[SparkExpression],
        newList: Seq[SparkExpression],
        dt: SparkDataType,
        hasNot: Boolean = true): Option[Expression] = {
      if (!newList.equals(list)) {
        val hSet = list.map(e => e.eval(EmptyRow))
        if (hasNot) {
          // NOT IN with a null anywhere in the list is pushed down as a false expression.
          if (hSet.contains(null)) {
            Some(new FalseExpression(translateColumn(attributeName, dt)))
          } else {
            Some(new NotInExpression(translateColumn(attributeName, dt),
              new ListExpression(convertToJavaList(
                hSet.map(f => translateLiteral(f, dt)).toList))))
          }
        } else {
          // IN whose only value is null is pushed down as a false expression; otherwise nulls
          // are dropped from the value list before it is translated.
          if (hSet.length == 1 && hSet.head == null) {
            Some(new FalseExpression(translateColumn(attributeName, dt)))
          } else {
            Some(new InExpression(translateColumn(attributeName, dt),
              new ListExpression(convertToJavaList(hSet
                .filterNot(_ == null)
                .map(filterValues => translateLiteral(filterValues, dt))
                .toList))))
          }
        }
      } else {
        Some(transformExpression(expr))
      }
    }

    // Handles IN (hasNot = false) and NOT IN (hasNot = true). The literal type is taken from the
    // first list element.
    def checkInExpression(attribute: Attribute,
        list: Seq[SparkExpression],
        hasNot: Boolean = true): Option[Expression] = {
      // headOption instead of head: an empty IN list must fall through to the generic
      // transformation below rather than throw NoSuchElementException.
      def firstValueTypeIs(expected: SparkDataType): Boolean = {
        list.headOption.exists(_.dataType.sameType(expected))
      }

      attribute.dataType match {
        case _: DateType | _: TimestampType if firstValueTypeIs(StringType) =>
          checkInValueList(attribute.name, list, typeCastStringToLongList(list, attribute.dataType),
            attribute.dataType, hasNot)
        case _: IntegerType if firstValueTypeIs(DoubleType) =>
          checkInValueList(attribute.name, list, typeCastDoubleToIntList(list), attribute.dataType,
            hasNot)
        case _: ShortType if firstValueTypeIs(IntegerType) =>
          checkInValueList(attribute.name, list, typeCastIntToShortList(list), attribute.dataType,
            hasNot)
        case _ => Some(transformExpression(expr))
      }
    }

    expr match {
      // Plain equality is the only comparison handled with nonEqual = false.
      case EqualTo(Cast(a: Attribute, _), Literal(v, t)) =>
        checkBinaryExpression(a.dataType, v, t, nonEqual = false)
      case EqualTo(Literal(v, t), Cast(a: Attribute, _)) =>
        checkBinaryExpression(a.dataType, v, t, nonEqual = false)
      // A negated equality is a non-equal comparison, so it uses the default nonEqual = true,
      // mirroring how NOT IN uses the default hasNot = true below.
      case Not(EqualTo(Cast(a: Attribute, _), Literal(v, t))) =>
        checkBinaryExpression(a.dataType, v, t)
      case Not(EqualTo(Literal(v, t), Cast(a: Attribute, _))) =>
        checkBinaryExpression(a.dataType, v, t)
      case Not(In(Cast(a: Attribute, _), list)) =>
        checkInExpression(a, list)
      case In(Cast(a: Attribute, _), list) =>
        checkInExpression(a, list, hasNot = false)
      case GreaterThan(Cast(a: Attribute, _), Literal(v, t)) =>
        checkBinaryExpression(a.dataType, v, t)
      case GreaterThan(Literal(v, t), Cast(a: Attribute, _)) =>
        checkBinaryExpression(a.dataType, v, t)
      case LessThan(Cast(a: Attribute, _), Literal(v, t)) =>
        checkBinaryExpression(a.dataType, v, t)
      case LessThan(Literal(v, t), Cast(a: Attribute, _)) =>
        checkBinaryExpression(a.dataType, v, t)
      case GreaterThanOrEqual(Cast(a: Attribute, _), Literal(v, t)) =>
        checkBinaryExpression(a.dataType, v, t)
      case GreaterThanOrEqual(Literal(v, t), Cast(a: Attribute, _)) =>
        checkBinaryExpression(a.dataType, v, t)
      case LessThanOrEqual(Cast(a: Attribute, _), Literal(v, t)) =>
        checkBinaryExpression(a.dataType, v, t)
      case LessThanOrEqual(Literal(v, t), Cast(a: Attribute, _)) =>
        checkBinaryExpression(a.dataType, v, t)
    }
  }