in integration/spark/src/main/scala/org/apache/spark/sql/execution/CastExpressionOptimization.scala [158:260]
/**
 * Translates a Spark predicate whose attribute side is wrapped in a `Cast` into an
 * [[Expression]], dispatching on the attribute's own data type and the literal's data type.
 *
 * Handled predicate shapes (the cast attribute may be on either side of a binary comparison):
 * `EqualTo`, `Not(EqualTo)`, `GreaterThan`, `GreaterThanOrEqual`, `LessThan`,
 * `LessThanOrEqual`, `In` and `Not(In)`.
 *
 * Type pairs that get dedicated handling:
 *  - Date/Timestamp attribute compared with String literal(s)
 *  - Integer attribute compared with Double literal(s)
 *  - Short attribute compared with Integer literal(s)
 *
 * Every other type pair is delegated to `transformExpression(expr)`.
 *
 * Note: the final `match` has no default case, so an `expr` of any other shape raises a
 * `scala.MatchError`.
 *
 * @param expr the Spark predicate to translate
 * @return the translated expression, as produced by the helper selected for the type pair
 */
def checkIfCastCanBeRemove(expr: SparkExpression): Option[Expression] = {

  /**
   * Handles a binary comparison between a cast attribute and a literal.
   *
   * @param attributeType data type of the attribute underneath the cast
   * @param value         the literal's value
   * @param valueType     the literal's data type
   * @param nonEqual      when true, a Date/Timestamp result is additionally passed through
   *                      `updateFilterForNonEqualTimeStamp`; it also disables the array
   *                      element-type recursion
   */
  def checkBinaryExpression(attributeType: SparkDataType,
      value: Any,
      valueType: SparkDataType,
      nonEqual: Boolean = true): Option[Expression] = {
    attributeType match {
      case _: DateType | _: TimestampType if valueType.sameType(StringType) =>
        val filter = updateFilterForTimeStamp(value, expr, attributeType)
        if (nonEqual) updateFilterForNonEqualTimeStamp(value, expr, filter) else filter
      case _: IntegerType if valueType.sameType(DoubleType) =>
        updateFilterForInt(value, expr, attributeType)
      case _: ShortType if valueType.sameType(IntegerType) =>
        updateFilterForShort(value, expr, attributeType)
      case arr: ArrayType if !nonEqual =>
        // Re-run the same dispatch against the array's element type.
        checkBinaryExpression(arr.elementType, value, valueType, nonEqual)
      case _ =>
        Some(transformExpression(expr))
    }
  }

  /**
   * Builds the IN / NOT IN expression for an attribute once the literal list has been
   * type-cast.
   *
   * If the cast list (`newList`) is equal to the original `list`, the predicate is handed
   * to `transformExpression(expr)` unchanged. Otherwise each original list element is
   * evaluated against `EmptyRow` and:
   *  - NOT IN: any null value yields a `FalseExpression`; otherwise a `NotInExpression`
   *    over all values.
   *  - IN: a list consisting only of nulls yields a `FalseExpression`; otherwise an
   *    `InExpression` over the non-null values.
   *
   * @param attributeName name of the attribute underneath the cast
   * @param list          the original IN-list expressions
   * @param newList       the IN-list after type casting
   * @param dt            data type of the attribute
   * @param hasNot        true for NOT IN, false for IN
   */
  def checkInValueList(attributeName: String,
      list: Seq[SparkExpression],
      newList: Seq[SparkExpression],
      dt: SparkDataType,
      hasNot: Boolean = true): Option[Expression] = {
    if (newList != list) {
      val values = list.map(_.eval(EmptyRow))
      val column = translateColumn(attributeName, dt)

      def literalList(items: Seq[Any]): ListExpression =
        new ListExpression(convertToJavaList(items.map(translateLiteral(_, dt)).toList))

      if (hasNot) {
        if (values.contains(null)) {
          Some(new FalseExpression(column))
        } else {
          Some(new NotInExpression(column, literalList(values)))
        }
      } else if (values.nonEmpty && values.forall(_ == null)) {
        // Only nulls in the list: short-circuit instead of building an IN over an
        // empty literal list (covers one null as well as several).
        Some(new FalseExpression(column))
      } else {
        Some(new InExpression(column, literalList(values.filterNot(_ == null))))
      }
    } else {
      Some(transformExpression(expr))
    }
  }

  /**
   * Handles IN / NOT IN over a cast attribute. The type of the list is taken from its
   * first element; an empty list matches none of the dedicated type pairs and is passed
   * to `transformExpression(expr)`.
   *
   * @param attribute the attribute underneath the cast
   * @param list      the IN-list expressions
   * @param hasNot    true for NOT IN, false for IN
   */
  def checkInExpression(attribute: Attribute,
      list: Seq[SparkExpression],
      hasNot: Boolean = true): Option[Expression] = {
    // headOption keeps the guards total: `list.head` would throw on an empty list.
    def literalsAre(dt: SparkDataType): Boolean =
      list.headOption.exists(_.dataType.sameType(dt))

    val attrType = attribute.dataType
    attrType match {
      case _: DateType | _: TimestampType if literalsAre(StringType) =>
        checkInValueList(attribute.name, list, typeCastStringToLongList(list, attrType),
          attrType, hasNot)
      case _: IntegerType if literalsAre(DoubleType) =>
        checkInValueList(attribute.name, list, typeCastDoubleToIntList(list), attrType, hasNot)
      case _: ShortType if literalsAre(IntegerType) =>
        checkInValueList(attribute.name, list, typeCastIntToShortList(list), attrType, hasNot)
      case _ =>
        Some(transformExpression(expr))
    }
  }

  expr match {
    // Equality and negated equality: nonEqual = false.
    case EqualTo(Cast(a: Attribute, _), Literal(v, t)) =>
      checkBinaryExpression(a.dataType, v, t, nonEqual = false)
    case EqualTo(Literal(v, t), Cast(a: Attribute, _)) =>
      checkBinaryExpression(a.dataType, v, t, nonEqual = false)
    case Not(EqualTo(Cast(a: Attribute, _), Literal(v, t))) =>
      checkBinaryExpression(a.dataType, v, t, nonEqual = false)
    case Not(EqualTo(Literal(v, t), Cast(a: Attribute, _))) =>
      checkBinaryExpression(a.dataType, v, t, nonEqual = false)

    // Set membership.
    case Not(In(Cast(a: Attribute, _), list)) =>
      checkInExpression(a, list)
    case In(Cast(a: Attribute, _), list) =>
      checkInExpression(a, list, hasNot = false)

    // Ordering comparisons: nonEqual keeps its default (true).
    case GreaterThan(Cast(a: Attribute, _), Literal(v, t)) =>
      checkBinaryExpression(a.dataType, v, t)
    case GreaterThan(Literal(v, t), Cast(a: Attribute, _)) =>
      checkBinaryExpression(a.dataType, v, t)
    case LessThan(Cast(a: Attribute, _), Literal(v, t)) =>
      checkBinaryExpression(a.dataType, v, t)
    case LessThan(Literal(v, t), Cast(a: Attribute, _)) =>
      checkBinaryExpression(a.dataType, v, t)
    case GreaterThanOrEqual(Cast(a: Attribute, _), Literal(v, t)) =>
      checkBinaryExpression(a.dataType, v, t)
    case GreaterThanOrEqual(Literal(v, t), Cast(a: Attribute, _)) =>
      checkBinaryExpression(a.dataType, v, t)
    case LessThanOrEqual(Cast(a: Attribute, _), Literal(v, t)) =>
      checkBinaryExpression(a.dataType, v, t)
    case LessThanOrEqual(Literal(v, t), Cast(a: Attribute, _)) =>
      checkBinaryExpression(a.dataType, v, t)
  }
}