in native/spark-expr/src/array_funcs/array_insert.rs [196:312]
fn array_insert<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
items_array: &ArrayRef,
pos_array: &ArrayRef,
legacy_mode: bool,
) -> DataFusionResult<ColumnarValue> {
// This code is based on the implementation of array_append from Apache DataFusion
// https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/concat.rs#L513
//
// It is also based on the implementation of array_insert from Apache Spark
// https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4713
let values = list_array.values();
let offsets = list_array.offsets();
let values_data = values.to_data();
let item_data = items_array.to_data();
let new_capacity = Capacities::Array(values_data.len() + item_data.len());
let mut mutable_values =
MutableArrayData::with_capacities(vec![&values_data, &item_data], true, new_capacity);
let mut new_offsets = vec![O::usize_as(0)];
let mut new_nulls = Vec::<bool>::with_capacity(list_array.len());
let pos_data: &Int32Array = as_primitive_array(&pos_array); // Spark supports only i32 for positions
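// Positions are 1-based, as in Spark SQL; a negative position counts from the
// end of the array (see the worked examples in the comments below).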
for (row_index, offset_window) in offsets.windows(2).enumerate() {
let pos = pos_data.values()[row_index];
let start = offset_window[0].as_usize();
let end = offset_window[1].as_usize();
let is_item_null = items_array.is_null(row_index);
if list_array.is_null(row_index) {
// In Spark, if the input array is NULL then the result is NULL as well:
// emit a placeholder null element and mark the output row as null.
mutable_values.extend_nulls(1);
new_offsets.push(new_offsets[row_index] + O::one());
new_nulls.push(false);
continue;
}
if pos == 0 {
return Err(DataFusionError::Internal(
"Position for array_insert should be greter or less than zero".to_string(),
));
}
if (pos > 0) || ((-pos).as_usize() < (end - start + 1)) {
let corrected_pos = if pos > 0 {
(pos - 1).as_usize()
} else {
end - start - (-pos).as_usize() + if legacy_mode { 0 } else { 1 }
};
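// Worked example (illustrative): for an array [a, b, c] (len 3) and pos = -1,
// legacy mode yields corrected_pos = 3 - 1 + 0 = 2 (result [a, b, item, c]),
// while non-legacy mode yields corrected_pos = 3 - 1 + 1 = 3 (result [a, b, c, item]).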
let new_array_len = std::cmp::max(end - start + 1, corrected_pos);
if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
return Err(DataFusionError::Internal(format!(
"Max array length in Spark is {:?}, but got {:?}",
MAX_ROUNDED_ARRAY_LENGTH, new_array_len
)));
}
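// If the corrected position falls within the existing elements, splice the
// item in between them; otherwise copy the whole array, pad with nulls up to
// the position, and place the item at the end.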
if (start + corrected_pos) <= end {
mutable_values.extend(0, start, start + corrected_pos);
mutable_values.extend(1, row_index, row_index + 1);
mutable_values.extend(0, start + corrected_pos, end);
new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
} else {
mutable_values.extend(0, start, end);
mutable_values.extend_nulls(new_array_len - (end - start));
mutable_values.extend(1, row_index, row_index + 1);
// In this case Spark actually makes the array longer than expected;
// for example, if pos is 5 and the array length is 3, the resulting length will be 5
new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len) + O::one());
}
} else {
// This comment is taken from the Apache Spark source code as is:
// special case- if the new position is negative but larger than the current array size
// place the new item at start of array, place the current array contents at the end
// and fill the newly created array elements inbetween with a null
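// Worked example (illustrative, not from the sources above): for arr = [a, b]
// (len 2) and pos = -5 in non-legacy mode, new_array_len = 5 and the result
// is [item, null, null, a, b].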
let base_offset = if legacy_mode { 1 } else { 0 };
let new_array_len = (-pos + base_offset).as_usize();
if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
return Err(DataFusionError::Internal(format!(
"Max array length in Spark is {:?}, but got {:?}",
MAX_ROUNDED_ARRAY_LENGTH, new_array_len
)));
}
mutable_values.extend(1, row_index, row_index + 1);
mutable_values.extend_nulls(new_array_len - (end - start + 1));
mutable_values.extend(0, start, end);
new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
}
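// Output-row validity: the row is marked null only when the inserted item is
// null and the source row is empty (or the child values array is null at this
// index); otherwise the row stays valid even if it contains null elements.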
if is_item_null {
if (start == end) || (values.is_null(row_index)) {
new_nulls.push(false)
} else {
new_nulls.push(true)
}
} else {
new_nulls.push(true)
}
}
let data = make_array(mutable_values.freeze());
let data_type = match list_array.data_type() {
DataType::List(field) => field.data_type(),
DataType::LargeList(field) => field.data_type(),
_ => unreachable!(),
};
let new_array = GenericListArray::<O>::try_new(
Arc::new(Field::new("item", data_type.clone(), true)),
OffsetBuffer::new(new_offsets.into()),
data,
Some(NullBuffer::new(new_nulls.into())),
)?;
Ok(ColumnarValue::Array(Arc::new(new_array)))
}
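
// Illustrative sketch, not part of the original file: a minimal check of the
// behavior above, assuming the arrow builders below and this module's existing
// imports (ArrayRef, Arc, ColumnarValue, DataFusionResult) are available.
#[cfg(test)]
mod array_insert_sketch {
use super::*;
use arrow::array::{Int32Array, ListArray};
use arrow::datatypes::Int32Type;

#[test]
fn inserts_item_at_positive_position() -> DataFusionResult<()> {
// [[1, 2, 3]] with item 9 at 1-based position 2 should give [[1, 9, 2, 3]]
let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![
Some(1),
Some(2),
Some(3),
])]);
let items: ArrayRef = Arc::new(Int32Array::from(vec![9]));
let positions: ArrayRef = Arc::new(Int32Array::from(vec![2]));
let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![
Some(1),
Some(9),
Some(2),
Some(3),
])]);
match array_insert(&list, &items, &positions, false)? {
// Compare via ArrayData, which checks logical equality including nulls
ColumnarValue::Array(result) => assert_eq!(result.to_data(), expected.to_data()),
_ => unreachable!("array_insert always returns ColumnarValue::Array"),
}
Ok(())
}
}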