in sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala [205:320]
/**
 * Rewrites `exp`, a comparison between `Cast(fromExp, toType)` and the literal `value`, into an
 * expression that no longer casts `fromExp`, when that can be done without changing the result.
 *
 * Two strategies are tried in order:
 *  1. If `getRange` knows the `(min, max)` bounds of `fromExp.dataType` and `value` (compared
 *     with `toType`'s ordering) is at or beyond one of those bounds, the comparison is folded via
 *     `trueIfNotNull` / `falseIfNotNull`, or reduced to an (in)equality against that bound.
 *  2. Otherwise (no known bounds, e.g. decimal types, or `value` strictly inside them) the
 *     literal is cast down to `fromExp.dataType` and the operator is adjusted for any rounding
 *     that the down-cast introduced.
 *
 * `EqualNullSafe` is folded to `FalseLiteral` only when `exp` is deterministic, so that a
 * non-deterministic expression is never silently dropped from the plan.
 *
 * @param exp     the original comparison; returned unchanged when no rewrite applies
 * @param fromExp the expression being cast
 * @param toType  the numeric type `fromExp` is cast to
 * @param value   the literal operand; presumably a value of `toType`, since it is compared
 *                using `toType`'s ordering
 * @return the simplified expression, or `exp` itself
 */
private def simplifyNumericComparison(
    exp: BinaryComparison,
    fromExp: Expression,
    toType: NumericType,
    value: Any): Expression = {
  val fromType = fromExp.dataType
  val ordering = PhysicalDataType.ordering(toType)

  // Strategy 1: `value` sits on or outside the bounds of `fromType`. Yields `None` when
  // `fromType` has no known bounds or when `value` lies strictly between them.
  val boundaryRewrite: Option[Expression] = getRange(fromType).flatMap { case (min, max) =>
    val minInToType = Cast(Literal(min), toType).eval()
    val maxInToType = Cast(Literal(max), toType).eval()
    val minCmp = ordering.compare(value, minInToType)
    val maxCmp = ordering.compare(value, maxInToType)

    if (maxCmp > 0) {
      // `value` is above the upper bound of `fromType`.
      Some(exp match {
        case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
          falseIfNotNull(fromExp)
        case LessThan(_, _) | LessThanOrEqual(_, _) =>
          trueIfNotNull(fromExp)
        // make sure the expression is evaluated if it is non-deterministic
        case EqualNullSafe(_, _) if exp.deterministic =>
          FalseLiteral
        case _ => exp
      })
    } else if (maxCmp == 0) {
      // `value` equals the upper bound of `fromType`.
      val maxLit = Literal(max, fromType)
      Some(exp match {
        case GreaterThan(_, _) =>
          falseIfNotNull(fromExp)
        case LessThanOrEqual(_, _) =>
          trueIfNotNull(fromExp)
        case LessThan(_, _) =>
          Not(EqualTo(fromExp, maxLit))
        case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
          EqualTo(fromExp, maxLit)
        case EqualNullSafe(_, _) =>
          EqualNullSafe(fromExp, maxLit)
        case _ => exp
      })
    } else if (minCmp < 0) {
      // `value` is below the lower bound of `fromType`.
      Some(exp match {
        case GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
          trueIfNotNull(fromExp)
        case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) =>
          falseIfNotNull(fromExp)
        // make sure the expression is evaluated if it is non-deterministic
        case EqualNullSafe(_, _) if exp.deterministic =>
          FalseLiteral
        case _ => exp
      })
    } else if (minCmp == 0) {
      // `value` equals the lower bound of `fromType`.
      val minLit = Literal(min, fromType)
      Some(exp match {
        case LessThan(_, _) =>
          falseIfNotNull(fromExp)
        case GreaterThanOrEqual(_, _) =>
          trueIfNotNull(fromExp)
        case GreaterThan(_, _) =>
          Not(EqualTo(fromExp, minLit))
        case LessThanOrEqual(_, _) | EqualTo(_, _) =>
          EqualTo(fromExp, minLit)
        case EqualNullSafe(_, _) =>
          EqualNullSafe(fromExp, minLit)
        case _ => exp
      })
    } else {
      // Strictly inside (min, max): fall through to strategy 2.
      None
    }
  }

  boundaryRewrite.getOrElse {
    // Strategy 2: either there is no min/max for `fromType` (e.g. decimal type), or `value` is
    // within (min, max). Move the cast to the literal side. ANSI mode is disabled for this cast
    // so that a value not representable in the narrower `fromType` is expected to come back as
    // null (handled below) rather than raise an error.
    Option(Cast(Literal(value), fromType, ansiEnabled = false).eval()) match {
      case None =>
        // The cast failed, for instance because the value is not representable in the
        // narrower type. In this case we simply return the original expression.
        exp
      case Some(newValue) =>
        val lit = Literal(newValue, fromType)
        // Cast back up to `toType` to detect whether the down-cast rounded `value`.
        val cmp = ordering.compare(value, Cast(lit, toType).eval())
        if (cmp == 0) {
          // Exact round trip: the same comparison holds against `lit`.
          exp match {
            case GreaterThan(_, _) => GreaterThan(fromExp, lit)
            case GreaterThanOrEqual(_, _) => GreaterThanOrEqual(fromExp, lit)
            case EqualTo(_, _) => EqualTo(fromExp, lit)
            case EqualNullSafe(_, _) => EqualNullSafe(fromExp, lit)
            case LessThan(_, _) => LessThan(fromExp, lit)
            case LessThanOrEqual(_, _) => LessThanOrEqual(fromExp, lit)
            case _ => exp
          }
        } else if (cmp < 0) {
          // The literal value is rounded up after casting to `fromType` (`lit` > `value`).
          // Assuming the down-cast rounds to a neighbouring `fromType` value, `fromExp` can
          // never equal `value`, and `fromExp > value` is equivalent to `fromExp >= lit`.
          exp match {
            case EqualTo(_, _) => falseIfNotNull(fromExp)
            // make sure the expression is evaluated if it is non-deterministic
            case EqualNullSafe(_, _) if exp.deterministic => FalseLiteral
            case GreaterThan(_, _) | GreaterThanOrEqual(_, _) => GreaterThanOrEqual(fromExp, lit)
            case LessThan(_, _) | LessThanOrEqual(_, _) => LessThan(fromExp, lit)
            case _ => exp
          }
        } else {
          // The literal value is rounded down after casting to `fromType` (`lit` < `value`).
          // Under the same assumption, `fromExp < value` is equivalent to `fromExp <= lit`.
          exp match {
            case EqualTo(_, _) => falseIfNotNull(fromExp)
            // make sure the expression is evaluated if it is non-deterministic
            case EqualNullSafe(_, _) if exp.deterministic => FalseLiteral
            case GreaterThan(_, _) | GreaterThanOrEqual(_, _) => GreaterThan(fromExp, lit)
            case LessThan(_, _) | LessThanOrEqual(_, _) => LessThanOrEqual(fromExp, lit)
            case _ => exp
          }
        }
    }
  }
}