fn array_insert()

in native/spark-expr/src/array_funcs/array_insert.rs [196:312]
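
Inserts `items_array[i]` into the i-th list of `list_array` at the 1-based position `pos_array[i]`, following Spark's `array_insert` semantics: negative positions count from the end of the array, positions beyond either end pad the result with nulls, a NULL input array yields a NULL result row, and the `legacy_mode` flag reproduces Spark's legacy behavior for negative positions (cf. `spark.sql.legacy.negativeIndexInArrayInsert`).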


fn array_insert<O: OffsetSizeTrait>(
    list_array: &GenericListArray<O>,
    items_array: &ArrayRef,
    pos_array: &ArrayRef,
    legacy_mode: bool,
) -> DataFusionResult<ColumnarValue> {
    // The code is based on the implementation of `array_append` from Apache DataFusion:
    // https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/concat.rs#L513
    //
    // It is also based on the implementation of `array_insert` from Apache Spark:
    // https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4713

    let values = list_array.values();
    let offsets = list_array.offsets();
    let values_data = values.to_data();
    let item_data = items_array.to_data();
    let new_capacity = Capacities::Array(values_data.len() + item_data.len());

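    // Source 0 of `mutable_values` is the flattened child values of the input
    // lists; source 1 is the items to insert. `use_nulls = true` so we can pad
    // with nulls when an insert position falls past the end of a list.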
    let mut mutable_values =
        MutableArrayData::with_capacities(vec![&values_data, &item_data], true, new_capacity);

    let mut new_offsets = vec![O::usize_as(0)];
    let mut new_nulls = Vec::<bool>::with_capacity(list_array.len());

    let pos_data: &Int32Array = as_primitive_array(&pos_array); // Spark supports only i32 for positions

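    // Each consecutive pair of offsets delimits one list row in the flattened values.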
    for (row_index, offset_window) in offsets.windows(2).enumerate() {
        let pos = pos_data.values()[row_index];
        let start = offset_window[0].as_usize();
        let end = offset_window[1].as_usize();
        let is_item_null = items_array.is_null(row_index);

        if list_array.is_null(row_index) {
            // In Spark, if the array itself is NULL, nothing happens: the result row is NULL
            mutable_values.extend_nulls(1);
            new_offsets.push(new_offsets[row_index] + O::one());
            new_nulls.push(false);
            continue;
        }

        if pos == 0 {
            return Err(DataFusionError::Internal(
                "Position for array_insert should be greter or less than zero".to_string(),
            ));
        }

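        // Either a positive position, or a negative one that still points inside
        // the current array.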
        if (pos > 0) || ((-pos).as_usize() < (end - start + 1)) {
            let corrected_pos = if pos > 0 {
                (pos - 1).as_usize()
            } else {
                end - start - (-pos).as_usize() + if legacy_mode { 0 } else { 1 }
            };
            let new_array_len = std::cmp::max(end - start + 1, corrected_pos);
            if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
                return Err(DataFusionError::Internal(format!(
                    "Max array length in Spark is {:?}, but got {:?}",
                    MAX_ROUNDED_ARRAY_LENGTH, new_array_len
                )));
            }

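            // If the corrected position lands inside the existing values, splice the
            // item in between the two halves; otherwise copy all values, pad with
            // nulls, and append the item at the end.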
            if (start + corrected_pos) <= end {
                mutable_values.extend(0, start, start + corrected_pos);
                mutable_values.extend(1, row_index, row_index + 1);
                mutable_values.extend(0, start + corrected_pos, end);
                new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
            } else {
                mutable_values.extend(0, start, end);
                mutable_values.extend_nulls(new_array_len - (end - start));
                mutable_values.extend(1, row_index, row_index + 1);
                // In this case Spark actually makes the array longer than expected;
                // for example, if pos is 5 and the array length is 3, the resulting length will be 5
                new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len) + O::one());
            }
        } else {
            // This comment is taken from the Apache Spark source code as-is:
            // special case- if the new position is negative but larger than the current array size
            // place the new item at start of array, place the current array contents at the end
            // and fill the newly created array elements inbetween with a null
            let base_offset = if legacy_mode { 1 } else { 0 };
            let new_array_len = (-pos + base_offset).as_usize();
            if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
                return Err(DataFusionError::Internal(format!(
                    "Max array length in Spark is {:?}, but got {:?}",
                    MAX_ROUNDED_ARRAY_LENGTH, new_array_len
                )));
            }
            mutable_values.extend(1, row_index, row_index + 1);
            mutable_values.extend_nulls(new_array_len - (end - start + 1));
            mutable_values.extend(0, start, end);
            new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
        }
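        // Output-row validity: an inserted NULL marks the whole row NULL only when
        // the source list is empty or the child value at this row index is NULL.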
        if is_item_null {
            if (start == end) || (values.is_null(row_index)) {
                new_nulls.push(false)
            } else {
                new_nulls.push(true)
            }
        } else {
            new_nulls.push(true)
        }
    }

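    // Assemble the result list array, keeping the element type of the input and
    // the per-row validity computed above.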
    let data = make_array(mutable_values.freeze());
    let data_type = match list_array.data_type() {
        DataType::List(field) => field.data_type(),
        DataType::LargeList(field) => field.data_type(),
        _ => unreachable!(),
    };
    let new_array = GenericListArray::<O>::try_new(
        Arc::new(Field::new("item", data_type.clone(), true)),
        OffsetBuffer::new(new_offsets.into()),
        data,
        Some(NullBuffer::new(new_nulls.into())),
    )?;

    Ok(ColumnarValue::Array(Arc::new(new_array)))
}
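
As a sanity check on the positive-position path, here is a minimal test sketch. It is not part of the source: it assumes it lives in a #[cfg(test)] module at the bottom of array_insert.rs (so `array_insert`, `ColumnarValue`, and `DataFusionResult` come in via `use super::*;`), that the `arrow` crate is available as in the rest of the file, and the test name is made up.

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, ArrayRef, Int32Array, ListArray};
    use arrow::datatypes::Int32Type;
    use std::sync::Arc;

    #[test]
    fn test_array_insert_positive_position() -> DataFusionResult<()> {
        // array_insert([1, 2, 3], 4, 2) => [1, 4, 2, 3] (positions are 1-based)
        let list_array = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![
            Some(1),
            Some(2),
            Some(3),
        ])]);
        let items: ArrayRef = Arc::new(Int32Array::from(vec![4]));
        let positions: ArrayRef = Arc::new(Int32Array::from(vec![2]));

        let result = array_insert::<i32>(&list_array, &items, &positions, false)?;
        let ColumnarValue::Array(result) = result else {
            unreachable!("array_insert always returns an array")
        };

        let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![
            Some(1),
            Some(4),
            Some(2),
            Some(3),
        ])]);
        // Compare logical contents (data type, offsets, values, validity)
        assert_eq!(result.to_data(), expected.to_data());
        Ok(())
    }
}

A companion test with a negative position, run once with legacy_mode = true and once with false, would pin down the two negative-index behaviors the `legacy_mode` flag selects between.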