datafusion/spark/src/function/math/expm1.rs (114 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. use crate::function::error_utils::{ invalid_arg_count_exec_err, unsupported_data_type_exec_err, }; use arrow::array::{ArrayRef, AsArray}; use arrow::datatypes::{DataType, Float64Type}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use std::any::Any; use std::sync::Arc; /// <https://spark.apache.org/docs/latest/api/sql/index.html#expm1> #[derive(Debug)] pub struct SparkExpm1 { signature: Signature, aliases: Vec<String>, } impl Default for SparkExpm1 { fn default() -> Self { Self::new() } } impl SparkExpm1 { pub fn new() -> Self { Self { signature: Signature::user_defined(Volatility::Immutable), aliases: vec![], } } } impl ScalarUDFImpl for SparkExpm1 { fn as_any(&self) -> &dyn Any { self } fn name(&self) -> &str { "expm1" } fn signature(&self) -> &Signature { &self.signature } fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { Ok(DataType::Float64) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { if args.args.len() != 1 { return Err(invalid_arg_count_exec_err("expm1", (1, 1), args.args.len())); } match &args.args[0] { ColumnarValue::Scalar(ScalarValue::Float64(value)) => Ok( ColumnarValue::Scalar(ScalarValue::Float64(value.map(|x| x.exp_m1()))), ), ColumnarValue::Array(array) => match array.data_type() { DataType::Float64 => Ok(ColumnarValue::Array(Arc::new( array .as_primitive::<Float64Type>() .unary::<_, Float64Type>(|x| x.exp_m1()), ) as ArrayRef)), other => Err(unsupported_data_type_exec_err( "expm1", format!("{}", DataType::Float64).as_str(), other, )), }, other => Err(unsupported_data_type_exec_err( "expm1", format!("{}", DataType::Float64).as_str(), &other.data_type(), )), } } fn aliases(&self) -> &[String] { &self.aliases } fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { if arg_types.len() != 1 { return Err(invalid_arg_count_exec_err("expm1", (1, 1), arg_types.len())); } if arg_types[0].is_numeric() { Ok(vec![DataType::Float64]) } else { Err(unsupported_data_type_exec_err( "expm1", "Numeric Type", &arg_types[0], )) } } } #[cfg(test)] mod tests { use crate::function::math::expm1::SparkExpm1; use crate::function::utils::test::test_scalar_function; use arrow::array::{Array, Float64Array}; use arrow::datatypes::DataType::Float64; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; macro_rules! test_expm1_float64_invoke { ($INPUT:expr, $EXPECTED:expr) => { test_scalar_function!( SparkExpm1::new(), vec![ColumnarValue::Scalar(ScalarValue::Float64($INPUT))], $EXPECTED, f64, Float64, Float64Array ); }; } #[test] fn test_expm1_invoke() -> Result<()> { test_expm1_float64_invoke!(Some(0f64), Ok(Some(0.0f64))); Ok(()) } }