native/spark-expr/src/array_funcs/array_repeat.rs (147 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. use arrow::array::{ new_null_array, Array, ArrayRef, Capacities, GenericListArray, ListArray, MutableArrayData, NullBufferBuilder, OffsetSizeTrait, UInt64Array, }; use arrow::buffer::OffsetBuffer; use arrow::compute; use arrow::compute::cast; use arrow::datatypes::DataType::{LargeList, List}; use arrow::datatypes::{DataType, Field}; use datafusion::common::cast::{as_large_list_array, as_list_array, as_uint64_array}; use datafusion::common::{exec_err, DataFusionError, ScalarValue}; use datafusion::logical_expr::ColumnarValue; use std::sync::Arc; pub fn make_scalar_function<F>( inner: F, ) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> where F: Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError>, { move |args: &[ColumnarValue]| { // first, identify if any of the arguments is an Array. If yes, store its `len`, // as any scalar will need to be converted to an array of len `len`. let len = args .iter() .fold(Option::<usize>::None, |acc, arg| match arg { ColumnarValue::Scalar(_) => acc, ColumnarValue::Array(a) => Some(a.len()), }); let is_scalar = len.is_none(); let args = ColumnarValue::values_to_arrays(args)?; let result = (inner)(&args); if is_scalar { // If all inputs are scalar, keeps output as scalar let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); result.map(ColumnarValue::Scalar) } else { result.map(ColumnarValue::Array) } } } pub fn spark_array_repeat(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> { make_scalar_function(spark_array_repeat_inner)(args) } /// Array_repeat SQL function fn spark_array_repeat_inner(args: &[ArrayRef]) -> datafusion::common::Result<ArrayRef> { let element = &args[0]; let count_array = &args[1]; let count_array = match count_array.data_type() { DataType::Int64 => &cast(count_array, &DataType::UInt64)?, DataType::UInt64 => count_array, _ => return exec_err!("count must be an integer type"), }; let count_array = as_uint64_array(count_array)?; match element.data_type() { List(_) => { let list_array = as_list_array(element)?; general_list_repeat::<i32>(list_array, count_array) } LargeList(_) => { let list_array = as_large_list_array(element)?; general_list_repeat::<i64>(list_array, count_array) } _ => general_repeat::<i32>(element, count_array), } } /// For each element of `array[i]` repeat `count_array[i]` times. /// /// Assumption for the input: /// 1. `count[i] >= 0` /// 2. `array.len() == count_array.len()` /// /// For example, /// ```text /// array_repeat( /// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]] /// ) /// ``` fn general_repeat<O: OffsetSizeTrait>( array: &ArrayRef, count_array: &UInt64Array, ) -> datafusion::common::Result<ArrayRef> { let data_type = array.data_type(); let mut new_values = vec![]; let count_vec = count_array .values() .to_vec() .iter() .map(|x| *x as usize) .collect::<Vec<_>>(); let mut nulls = NullBufferBuilder::new(count_array.len()); for (row_index, &count) in count_vec.iter().enumerate() { nulls.append(!count_array.is_null(row_index)); let repeated_array = if array.is_null(row_index) { new_null_array(data_type, count) } else { let original_data = array.to_data(); let capacity = Capacities::Array(count); let mut mutable = MutableArrayData::with_capacities(vec![&original_data], false, capacity); for _ in 0..count { mutable.extend(0, row_index, row_index + 1); } let data = mutable.freeze(); arrow::array::make_array(data) }; new_values.push(repeated_array); } let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); let values = compute::concat(&new_values)?; Ok(Arc::new(GenericListArray::<O>::try_new( Arc::new(Field::new_list_field(data_type.to_owned(), true)), OffsetBuffer::from_lengths(count_vec), values, nulls.finish(), )?)) } /// Handle List version of `general_repeat` /// /// For each element of `list_array[i]` repeat `count_array[i]` times. /// /// For example, /// ```text /// array_repeat( /// [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]] /// ) /// ``` fn general_list_repeat<O: OffsetSizeTrait>( list_array: &GenericListArray<O>, count_array: &UInt64Array, ) -> datafusion::common::Result<ArrayRef> { let data_type = list_array.data_type(); let value_type = list_array.value_type(); let mut new_values = vec![]; let count_vec = count_array .values() .to_vec() .iter() .map(|x| *x as usize) .collect::<Vec<_>>(); for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) { let list_arr = match list_array_row { Some(list_array_row) => { let original_data = list_array_row.to_data(); let capacity = Capacities::Array(original_data.len() * count); let mut mutable = MutableArrayData::with_capacities(vec![&original_data], false, capacity); for _ in 0..count { mutable.extend(0, 0, original_data.len()); } let data = mutable.freeze(); let repeated_array = arrow::array::make_array(data); let list_arr = GenericListArray::<O>::try_new( Arc::new(Field::new_list_field(value_type.clone(), true)), OffsetBuffer::<O>::from_lengths(vec![original_data.len(); count]), repeated_array, None, )?; Arc::new(list_arr) as ArrayRef } None => new_null_array(data_type, count), }; new_values.push(list_arr); } let lengths = new_values.iter().map(|a| a.len()).collect::<Vec<_>>(); let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); let values = compute::concat(&new_values)?; Ok(Arc::new(ListArray::try_new( Arc::new(Field::new_list_field(data_type.to_owned(), true)), OffsetBuffer::<i32>::from_lengths(lengths), values, None, )?)) }