in datafusion/optimizer/src/unwrap_cast_in_comparison.rs [300:454]
/// Tries to re-express the literal `lit_value` as a literal of `target_type`
/// holding the same numeric value.
///
/// Returns:
/// * `Ok(Some(scalar))` when the literal can be represented in `target_type`.
///   A null literal always maps to the null value of `target_type`.
/// * `Ok(None)` when either data type is rejected by `is_support_data_type`,
///   when the scaled value overflows `i128`, falls outside the range of
///   `target_type`, or would lose fractional decimal digits, or when a decimal
///   precision/scale cannot be handled by the arithmetic used here (negative
///   scale, `10^scale` overflowing `i128`, precision outside the lookup tables).
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when a target type or literal passes
/// `is_support_data_type` but has no conversion arm in this function.
fn try_cast_literal_to_type(
    lit_value: &ScalarValue,
    target_type: &DataType,
) -> Result<Option<ScalarValue>> {
    /// `10^scale` as an `i128`; `None` if `scale` is negative or the power
    /// does not fit in an `i128`.
    fn checked_pow10(scale: i32) -> Option<i128> {
        if scale < 0 {
            None
        } else {
            // `scale` is non-negative here, so the cast to `u32` is lossless.
            10_i128.checked_pow(scale as u32)
        }
    }

    let lit_data_type = lit_value.get_datatype();
    // Only the data types accepted by `is_support_data_type` are handled.
    if !is_support_data_type(&lit_data_type) || !is_support_data_type(target_type) {
        return Ok(None);
    }
    if lit_value.is_null() {
        // A null value can be cast to the null value of any type.
        return Ok(Some(ScalarValue::try_from(target_type)?));
    }

    // `mul` is the factor that turns a whole number into the integer
    // representation used by `target_type` (10^scale for decimals, 1 otherwise);
    // `target_min..=target_max` is the range of representable values.
    let (mul, target_min, target_max) = match target_type {
        DataType::UInt8 => (1_i128, u8::MIN as i128, u8::MAX as i128),
        DataType::UInt16 => (1_i128, u16::MIN as i128, u16::MAX as i128),
        DataType::UInt32 => (1_i128, u32::MIN as i128, u32::MAX as i128),
        DataType::UInt64 => (1_i128, u64::MIN as i128, u64::MAX as i128),
        DataType::Int8 => (1_i128, i8::MIN as i128, i8::MAX as i128),
        DataType::Int16 => (1_i128, i16::MIN as i128, i16::MAX as i128),
        DataType::Int32 => (1_i128, i32::MIN as i128, i32::MAX as i128),
        DataType::Int64 => (1_i128, i64::MIN as i128, i64::MAX as i128),
        DataType::Timestamp(_, _) => (1_i128, i64::MIN as i128, i64::MAX as i128),
        DataType::Decimal128(precision, scale) => {
            // A negative scale, or one whose power of ten overflows `i128`,
            // cannot be handled here: skip the rewrite instead of panicking.
            let mul = match checked_pow10(i32::from(*scale)) {
                Some(mul) => mul,
                None => return Ok(None),
            };
            // Different precision for decimal128 can store different range of value.
            // For example, the precision is 3, the max of value is `999` and the min
            // value is `-999`.
            // The tables are indexed by `precision - 1`; a precision of 0 or one
            // beyond the table length has no entry, so skip the rewrite.
            let bounds = (*precision as usize).checked_sub(1).and_then(|idx| {
                let min = MIN_DECIMAL_FOR_EACH_PRECISION.get(idx).copied()?;
                let max = MAX_DECIMAL_FOR_EACH_PRECISION.get(idx).copied()?;
                Some((min, max))
            });
            match bounds {
                Some((min, max)) => (mul, min, max),
                None => return Ok(None),
            }
        }
        other_type => {
            return Err(DataFusionError::Internal(format!(
                "Error target data type {other_type:?}"
            )));
        }
    };

    // The literal expressed in the integer representation of `target_type`;
    // `None` means it cannot be represented without overflow or precision loss.
    let lit_value_target_type = match lit_value {
        ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul),
        ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
        ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
        ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
        ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul),
        ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul),
        ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul),
        ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul),
        ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul),
        ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul),
        ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul),
        ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul),
        ScalarValue::Decimal128(Some(v), _, scale) => {
            let lit_scale_mul = match checked_pow10(i32::from(*scale)) {
                Some(lit_scale_mul) => lit_scale_mul,
                // same reasoning as for the target scale above
                None => return Ok(None),
            };
            // Both factors are powers of ten >= 1, so the divisions below are
            // exact and never divide by zero.
            if mul >= lit_scale_mul {
                // Example:
                // lit is decimal(123,3,2)
                // target type is decimal(5,3)
                // the lit can be converted to the decimal(1230,5,3)
                (*v).checked_mul(mul / lit_scale_mul)
            } else if (*v) % (lit_scale_mul / mul) == 0 {
                // Example:
                // lit is decimal(123000,10,3)
                // target type is int32: the lit can be converted to INT32(123)
                // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
                Some(*v / (lit_scale_mul / mul))
            } else {
                // can't convert the lit decimal to the target data type
                None
            }
        }
        other_value => {
            return Err(DataFusionError::Internal(format!(
                "Invalid literal value {other_value:?}"
            )));
        }
    };

    // Give up unless the converted value exists and fits in the target type.
    let value = match lit_value_target_type {
        Some(value) if (target_min..=target_max).contains(&value) => value,
        _ => return Ok(None),
    };

    // The range check above guarantees that none of the narrowing `as` casts
    // below can truncate.
    let result_scalar = match target_type {
        DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
        DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
        DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
        DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
        DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)),
        DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)),
        DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
        DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
        DataType::Timestamp(unit, tz) => {
            // NOTE(review): the helper's result is stored directly as the
            // optional payload of the timestamp scalar. If it yields `None`
            // for a non-null literal (presumably on overflow while rescaling
            // between units -- the helper is not visible here), the result is a
            // NULL timestamp literal rather than `Ok(None)`. Behaviour kept
            // as-is; confirm it is intended.
            let converted =
                cast_between_timestamp(lit_data_type, target_type.clone(), value);
            match unit {
                TimeUnit::Second => ScalarValue::TimestampSecond(converted, tz.clone()),
                TimeUnit::Millisecond => {
                    ScalarValue::TimestampMillisecond(converted, tz.clone())
                }
                TimeUnit::Microsecond => {
                    ScalarValue::TimestampMicrosecond(converted, tz.clone())
                }
                TimeUnit::Nanosecond => {
                    ScalarValue::TimestampNanosecond(converted, tz.clone())
                }
            }
        }
        DataType::Decimal128(p, s) => ScalarValue::Decimal128(Some(value), *p, *s),
        other_type => {
            return Err(DataFusionError::Internal(format!(
                "Error target data type {other_type:?}"
            )));
        }
    };
    Ok(Some(result_scalar))
}