// velox/functions/prestosql/ArrayIntersectExcept.cpp (370 lines of code) (raw):
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/expression/VectorFunction.h"
#include "velox/functions/lib/LambdaFunctionUtil.h"
namespace facebook::velox::functions {
namespace {
/// A hash set of non-null values of type T paired with a flag that records
/// whether a null was observed. Keeping the null marker outside the set
/// avoids storing optional values or reserving a sentinel value for null.
template <typename T>
struct SetWithNull {
  /// Default number of elements the set reserves room for on construction.
  static constexpr vector_size_t kInitialSetSize{128};

  /// Creates an empty set (hasNull == false) and reserves room for
  /// `initialSetSize` elements.
  ///
  /// Marked `explicit` so that an integer is never silently converted into a
  /// SetWithNull; default construction (`SetWithNull<T> s;`) is unaffected.
  explicit SetWithNull(vector_size_t initialSetSize = kInitialSetSize) {
    set.reserve(initialSetSize);
  }

  /// Removes all values and clears the null flag so that the same object can
  /// be reused for the next row.
  void reset() {
    set.clear();
    hasNull = false;
  }

  /// Distinct non-null values.
  std::unordered_set<T> set;

  /// True if a null value was recorded since construction or the last reset().
  bool hasNull{false};
};
// Generates a set based on the elements of an ArrayVector. Note that we take
// rightSet as a parameter (instead of returning a new one) to reuse the
// allocated memory.
template <typename T, typename TVector>
void generateSet(
const ArrayVector* arrayVector,
const TVector* arrayElements,
vector_size_t idx,
SetWithNull<T>& rightSet) {
auto size = arrayVector->sizeAt(idx);
auto offset = arrayVector->offsetAt(idx);
rightSet.reset();
for (vector_size_t i = offset; i < (offset + size); ++i) {
if (arrayElements->isNullAt(i)) {
rightSet.hasNull = true;
} else {
// Function can be called with either FlatVector or DecodedVector, but
// their APIs are slightly different.
if constexpr (std::is_same_v<TVector, DecodedVector>) {
rightSet.set.insert(arrayElements->template valueAt<T>(i));
} else {
rightSet.set.insert(arrayElements->valueAt(i));
}
}
}
}
// Decodes the elements vector of the array vector held by `arrayDecoder`.
//
// `arrayDecoder` must already hold a decoded array vector. The element rows
// that need decoding are derived via toElementRows() from `rows`, the base
// ArrayVector and the decoded array indices. The decoded elements are stored
// in `elementsDecoder`, and the returned pointer is the DecodedVector obtained
// from `elementsDecoder.get()`.
DecodedVector* decodeArrayElements(
    exec::LocalDecodedVector& arrayDecoder,
    exec::LocalDecodedVector& elementsDecoder,
    const SelectivityVector& rows) {
  DecodedVector* decodedArray = arrayDecoder.get();
  const auto* baseArray = decodedArray->base()->as<ArrayVector>();

  // Figure out which positions of the elements vector are reachable from the
  // selected rows.
  const auto& elements = baseArray->elements();
  auto elementRows = toElementRows(
      elements->size(), rows, baseArray, decodedArray->indices());

  DecodedVector* decodedElements = elementsDecoder.get();
  decodedElements->decode(*elements, elementRows);
  return decodedElements;
}
// See documentation at https://prestodb.io/docs/current/functions/array.html
template <bool isIntersect, typename T>
class ArrayIntersectExceptFunction : public exec::VectorFunction {
 public:
  /// This class is used for both array_intersect and array_except functions
  /// (behavior controlled at compile time by the isIntersect template
  /// variable). Both these functions take two ArrayVectors as inputs (left and
  /// right) and leverage two sets to calculate the intersection (or except):
  ///
  /// - rightSet: a set that contains all (distinct) elements from the
  ///   right-hand side array.
  /// - outputSet: a set that contains the elements that were already added to
  ///   the output (to prevent duplicates).
  ///
  /// Along with each set, we maintain a `hasNull` flag that indicates whether
  /// null is present in the arrays, to prevent the use of optional types or
  /// special values.
  ///
  /// Zero element copy:
  ///
  /// In order to prevent copies of array elements, the function reuses the
  /// internal elements() vector from the left-hand side ArrayVector.
  ///
  /// First a new vector is created containing the indices of the elements
  /// which will be present in the output, and wrapped into a DictionaryVector.
  /// Next the `lengths` and `offsets` vectors that control where output arrays
  /// start and end are wrapped into the output ArrayVector.
  ///
  /// Constant optimization:
  ///
  /// If any of the values passed to array_intersect() or rhs for array_except()
  /// are constant (array literals) we create a set before instantiating the
  /// object and pass as a constructor parameter (constantSet).
  ArrayIntersectExceptFunction() = default;

  /// @param constantSet Set pre-computed from the constant array argument.
  /// @param isLeftConstant True if `constantSet` was built from the first
  /// (left) argument, false if it was built from the second (right) one.
  explicit ArrayIntersectExceptFunction(
      SetWithNull<T> constantSet,
      bool isLeftConstant)
      : constantSet_(std::move(constantSet)), isLeftConstant_(isLeftConstant) {}

  /// Evaluates the function for the selected `rows`.
  ///
  /// @param rows Rows to evaluate; may be sparse.
  /// @param args The two array arguments (left, right).
  /// @param context Provides the memory pool and decoded-vector helpers.
  /// @param result Receives the output via EvalCtx::moveOrCopyResult().
  void apply(
      const SelectivityVector& rows,
      std::vector<VectorPtr>& args,
      const TypePtr& /* outputType */,
      exec::EvalCtx* context,
      VectorPtr* result) const override {
    memory::MemoryPool* pool = context->pool();
    BaseVector* left = args[0].get();
    BaseVector* right = args[1].get();

    // For array_intersect, if there's a constant input, then require it is on
    // the right side. For array_except, the constant optimization only applies
    // if the constant is on the rhs, so this swap is not applicable.
    if constexpr (isIntersect) {
      if (constantSet_.has_value() && isLeftConstant_) {
        std::swap(left, right);
      }
    }

    exec::LocalDecodedVector leftHolder(context, *left, rows);
    auto decodedLeftArray = leftHolder.get();
    auto baseLeftArray = decodedLeftArray->base()->as<ArrayVector>();

    // Decode and acquire array elements vector.
    exec::LocalDecodedVector leftElementsDecoder(context);
    auto decodedLeftElements =
        decodeArrayElements(leftHolder, leftElementsDecoder, rows);

    auto leftElementsCount =
        countElements<ArrayVector>(rows, *decodedLeftArray);
    vector_size_t rowCount = left->size();

    // Allocate new vectors for indices, nulls, length and offsets. The output
    // can never hold more elements than the selected left-hand side arrays.
    BufferPtr newIndices = allocateIndices(leftElementsCount, pool);
    BufferPtr newElementNulls =
        AlignedBuffer::allocate<bool>(leftElementsCount, pool, bits::kNotNull);
    BufferPtr newLengths = allocateSizes(rowCount, pool);
    BufferPtr newOffsets = allocateOffsets(rowCount, pool);

    // Pointers and cursors to the raw data.
    auto rawNewIndices = newIndices->asMutable<vector_size_t>();
    auto rawNewElementNulls = newElementNulls->asMutable<uint64_t>();
    auto rawNewLengths = newLengths->asMutable<vector_size_t>();
    auto rawNewOffsets = newOffsets->asMutable<vector_size_t>();

    // Next free position in the output elements (indices / nulls) buffers.
    vector_size_t indicesCursor = 0;

    // Lambda that process each row. This is detached from the code so we can
    // apply it differently based on whether the right-hand side set is constant
    // or not.
    auto processRow = [&](vector_size_t row,
                          const SetWithNull<T>& rightSet,
                          SetWithNull<T>& outputSet) {
      auto idx = decodedLeftArray->index(row);
      auto size = baseLeftArray->sizeAt(idx);
      auto offset = baseLeftArray->offsetAt(idx);

      outputSet.reset();

      // Offsets and lengths must be written at position `row`: the output
      // ArrayVector is addressed by row number, and `rows` may skip rows, so
      // a cursor that advances once per processed row would store the array
      // boundaries of a sparse selection in the wrong slots.
      // NOTE(review): assumes left->size() >= rows.end() so that `row` is
      // within the buffers allocated above — confirm against callers.
      rawNewOffsets[row] = indicesCursor;

      // Scans the array elements on the left-hand side.
      for (vector_size_t i = offset; i < (offset + size); ++i) {
        if (decodedLeftElements->isNullAt(i)) {
          // For a NULL value not added to the output row yet, insert in
          // array_intersect if it was found on the rhs (and not found in the
          // case of array_except).
          if (!outputSet.hasNull) {
            bool setNull = false;
            if constexpr (isIntersect) {
              setNull = rightSet.hasNull;
            } else {
              setNull = !rightSet.hasNull;
            }
            if (setNull) {
              bits::setNull(rawNewElementNulls, indicesCursor++, true);
              outputSet.hasNull = true;
            }
          }
        } else {
          auto val = decodedLeftElements->valueAt<T>(i);
          // For array_intersect, add the element if it is found (not found
          // for array_except) in the right-hand side, and wasn't added already
          // (check outputSet).
          bool addValue = false;
          if constexpr (isIntersect) {
            addValue = rightSet.set.count(val) > 0;
          } else {
            addValue = rightSet.set.count(val) == 0;
          }
          if (addValue) {
            auto it = outputSet.set.insert(val);
            if (it.second) {
              rawNewIndices[indicesCursor++] = i;
            }
          }
        }
      }
      rawNewLengths[row] = indicesCursor - rawNewOffsets[row];
    };

    SetWithNull<T> outputSet;

    // Optimized case when the right-hand side array is constant.
    if (constantSet_.has_value()) {
      rows.applyToSelected([&](vector_size_t row) {
        processRow(row, *constantSet_, outputSet);
      });
    }
    // General case when no arrays are constant and both sets need to be
    // computed for each row.
    else {
      exec::LocalDecodedVector rightHolder(context, *right, rows);
      // Decode and acquire array elements vector.
      exec::LocalDecodedVector rightElementsHolder(context);
      auto decodedRightElements =
          decodeArrayElements(rightHolder, rightElementsHolder, rows);
      SetWithNull<T> rightSet;
      auto rightArrayVector = rightHolder.get()->base()->as<ArrayVector>();
      rows.applyToSelected([&](vector_size_t row) {
        auto idx = rightHolder.get()->index(row);
        generateSet<T>(rightArrayVector, decodedRightElements, idx, rightSet);
        processRow(row, rightSet, outputSet);
      });
    }

    // The output elements are a dictionary over the left-hand side elements;
    // only the first `indicesCursor` entries of the index buffer are used.
    auto newElements = BaseVector::wrapInDictionary(
        newElementNulls, newIndices, indicesCursor, baseLeftArray->elements());
    auto resultArray = std::make_shared<ArrayVector>(
        pool,
        ARRAY(CppToType<T>::create()),
        BufferPtr(nullptr),
        rowCount,
        newOffsets,
        newLengths,
        newElements,
        0);
    context->moveOrCopyResult(resultArray, rows, result);
  }

 private:
  // If one of the arrays is constant, this member holds the set generated
  // from its elements, which is calculated only once, before instantiating
  // this object. Empty when no constant optimization applies.
  std::optional<SetWithNull<T>> constantSet_;

  // If there's a `constantSet`, whether it refers to left or right-hand side.
  const bool isLeftConstant_{false};
}; // class ArrayIntersectExcept
/// Vector function backing arrays_overlap. For each selected row the result
/// is:
///  - true  if a non-null element of the probed array is found in the set
///          built from the other array;
///  - null  if no such element exists and either the set or the probed array
///          contains a null;
///  - false otherwise.
///
/// When one argument is a constant array, its set is computed once up-front
/// and handed to the constructor; the other argument is then the one probed.
template <typename T>
class ArraysOverlapFunction : public exec::VectorFunction {
 public:
  ArraysOverlapFunction() = default;

  /// @param constantSet Set pre-computed from the constant array argument.
  /// @param isLeftConstant True if the constant is the first argument.
  ArraysOverlapFunction(SetWithNull<T> constantSet, bool isLeftConstant)
      : constantSet_(std::move(constantSet)), isLeftConstant_(isLeftConstant) {}

  void apply(
      const SelectivityVector& rows,
      std::vector<VectorPtr>& args,
      const TypePtr& /* outputType */,
      exec::EvalCtx* context,
      VectorPtr* result) const override {
    // `probeSide` is scanned element by element; `buildSide` is turned into a
    // set. With a constant first argument the roles are exchanged so that the
    // pre-computed set always plays the build role.
    BaseVector* probeSide = args[0].get();
    BaseVector* buildSide = args[1].get();
    const bool hasConstantSet = constantSet_.has_value();
    if (hasConstantSet && isLeftConstant_) {
      std::swap(probeSide, buildSide);
    }

    exec::LocalDecodedVector arrayDecoder(context, *probeSide, rows);
    exec::LocalDecodedVector elementsDecoder(context);
    auto* probeElements =
        decodeArrayElements(arrayDecoder, elementsDecoder, rows);
    auto* probeArrays = arrayDecoder.get();
    auto* probeBase = probeArrays->base()->as<ArrayVector>();

    BaseVector::ensureWritable(rows, BOOLEAN(), context->pool(), result);
    auto* flatResult = (*result)->template asFlatVector<bool>();

    // Computes the result of a single row given the set of the other side.
    auto evaluateRow = [&](auto row, const SetWithNull<T>& rightSet) {
      const auto arrayIndex = probeArrays->index(row);
      const auto begin = probeBase->offsetAt(arrayIndex);
      const auto end = begin + probeBase->sizeAt(arrayIndex);

      // A null on either side turns a "no overlap" answer into null.
      bool sawNull = rightSet.hasNull;
      for (auto pos = begin; pos < end; ++pos) {
        if (probeElements->isNullAt(pos)) {
          // Null elements never match; just remember that one was seen.
          sawNull = true;
          continue;
        }
        if (rightSet.set.count(probeElements->valueAt<T>(pos)) > 0) {
          // A common element settles the answer regardless of nulls.
          flatResult->set(row, true);
          return;
        }
      }

      if (sawNull) {
        flatResult->setNull(row, true);
      } else {
        flatResult->set(row, false);
      }
    };

    // Fast path: the build side is a constant whose set already exists.
    if (hasConstantSet) {
      rows.applyToSelected(
          [&](vector_size_t row) { evaluateRow(row, *constantSet_); });
      return;
    }

    // General path: the build-side set is regenerated for every row.
    exec::LocalDecodedVector rightDecoder(context, *buildSide, rows);
    exec::LocalDecodedVector rightElementsDecoder(context);
    auto* buildElements =
        decodeArrayElements(rightDecoder, rightElementsDecoder, rows);
    SetWithNull<T> rowSet;
    auto* buildBase = rightDecoder.get()->base()->as<ArrayVector>();
    rows.applyToSelected([&](vector_size_t row) {
      const auto buildIndex = rightDecoder.get()->index(row);
      generateSet<T>(buildBase, buildElements, buildIndex, rowSet);
      evaluateRow(row, rowSet);
    });
  }

 private:
  // Holds the set generated from the constant array argument, if any. It is
  // computed once, before this object is instantiated.
  std::optional<SetWithNull<T>> constantSet_;

  // When `constantSet_` is present, tells whether it was built from the
  // left-hand side (true) or the right-hand side (false) argument.
  const bool isLeftConstant_{false};
}; // class ArraysOverlapFunction
// Checks the argument list of function `name`: there must be exactly
// `expectedArgCount` arguments, the first one must be of kind ARRAY, and every
// argument must be kind-equal to the first one. Any violation is reported
// through the VELOX_USER_CHECK* macros.
void validateMatchingArrayTypes(
    const std::vector<exec::VectorFunctionArg>& inputArgs,
    const std::string& name,
    vector_size_t expectedArgCount) {
  VELOX_USER_CHECK_EQ(
      inputArgs.size(),
      expectedArgCount,
      "{} requires exactly {} parameters",
      name,
      expectedArgCount);

  // The first argument is the reference every argument is compared against.
  const auto& referenceType = inputArgs[0].type;
  VELOX_USER_CHECK_EQ(
      referenceType->kind(),
      TypeKind::ARRAY,
      "{} requires arguments of type ARRAY",
      name);

  for (auto argIt = inputArgs.begin(); argIt != inputArgs.end(); ++argIt) {
    VELOX_USER_CHECK(
        referenceType->kindEquals(argIt->type),
        "{} function requires all arguments of the same type: {} vs. {}",
        name,
        argIt->type->toString(),
        referenceType->toString());
  }
}
// Builds the SetWithNull for a constant array argument.
//
// `baseVector` must be a ConstantVector<ComplexType> that wraps an ArrayVector
// whose elements are encoded as FlatVector<T>; each of these requirements is
// enforced with a VELOX_CHECK. The set is generated from the array at the
// constant's index().
//
// Every pointer is validated *before* it is dereferenced: a failed cast or a
// missing value vector raises a Velox error instead of dereferencing null.
template <typename T>
SetWithNull<T> validateConstantVectorAndGenerateSet(
    const BaseVector* baseVector) {
  auto constantArray = baseVector->as<ConstantVector<velox::ComplexType>>();
  VELOX_CHECK_NOT_NULL(constantArray, "wrong constant type found");

  const auto& valueVector = constantArray->valueVector();
  VELOX_CHECK_NOT_NULL(valueVector, "constant array value must not be null");

  auto arrayVecPtr = valueVector->as<ArrayVector>();
  VELOX_CHECK_NOT_NULL(arrayVecPtr, "wrong array literal type");

  auto elementsAsFlatVector = arrayVecPtr->elements()->as<FlatVector<T>>();
  VELOX_CHECK_NOT_NULL(
      elementsAsFlatVector, "constant value must be encoded as flat");

  auto idx = constantArray->index();
  SetWithNull<T> constantSet;
  generateSet<T>(arrayVecPtr, elementsAsFlatVector, idx, constantSet);
  return constantSet;
}
// Creates the array_intersect (isIntersect == true) or array_except
// (isIntersect == false) function object for element kind `kind`.
//
// A pre-computed constant set is used when an argument is constant, except
// for array_except with a constant left-hand side, where the generic
// per-row implementation is returned. If both arguments are constant, the
// left one supplies the set (array_intersect only).
template <bool isIntersect, TypeKind kind>
std::shared_ptr<exec::VectorFunction> createTypedArraysIntersectExcept(
    const std::vector<exec::VectorFunctionArg>& inputArgs) {
  VELOX_CHECK_EQ(inputArgs.size(), 2);

  using T = typename TypeTraits<kind>::NativeType;
  using Function = ArrayIntersectExceptFunction<isIntersect, T>;

  BaseVector* const lhsConstant = inputArgs[0].constantValue.get();
  BaseVector* const rhsConstant = inputArgs[1].constantValue.get();
  const bool isLeftConstant = lhsConstant != nullptr;

  // Neither argument is constant: nothing to pre-compute.
  if (!isLeftConstant && rhsConstant == nullptr) {
    return std::make_shared<Function>();
  }

  // Constant optimization is not supported for constant lhs for array_except.
  if constexpr (!isIntersect) {
    if (isLeftConstant) {
      return std::make_shared<Function>();
    }
  }

  return std::make_shared<Function>(
      validateConstantVectorAndGenerateSet<T>(
          isLeftConstant ? lhsConstant : rhsConstant),
      isLeftConstant);
}
std::shared_ptr<exec::VectorFunction> createArrayIntersect(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs) {
validateMatchingArrayTypes(inputArgs, name, 2);
auto elementType = inputArgs.front().type->childAt(0);
return VELOX_DYNAMIC_SCALAR_TEMPLATE_TYPE_DISPATCH(
createTypedArraysIntersectExcept,
/* isIntersect */ true,
elementType->kind(),
inputArgs);
}
std::shared_ptr<exec::VectorFunction> createArrayExcept(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs) {
validateMatchingArrayTypes(inputArgs, name, 2);
auto elementType = inputArgs.front().type->childAt(0);
return VELOX_DYNAMIC_SCALAR_TEMPLATE_TYPE_DISPATCH(
createTypedArraysIntersectExcept,
/* isIntersect */ false,
elementType->kind(),
inputArgs);
}
// Returns the single signature shared by the functions in this file:
// (array(T), array(T)) -> `returnType`, with T a type variable.
std::vector<std::shared_ptr<exec::FunctionSignature>> signatures(
    const std::string& returnType) {
  std::vector<std::shared_ptr<exec::FunctionSignature>> result;
  result.push_back(exec::FunctionSignatureBuilder()
                       .typeVariable("T")
                       .returnType(returnType)
                       .argumentType("array(T)")
                       .argumentType("array(T)")
                       .build());
  return result;
}
// Creates the arrays_overlap function object for element kind `kind`.
//
// If neither argument is constant, the generic per-row implementation is
// returned. Otherwise the set of the constant argument is computed once here
// and handed to the function object; when both arguments are constant, the
// left one supplies the set.
//
// Returns a plain (non-const) shared_ptr, matching the sibling factory
// createTypedArraysIntersectExcept; a top-level const on a by-value return
// type carries no meaning for callers.
template <TypeKind kind>
std::shared_ptr<exec::VectorFunction> createTypedArraysOverlap(
    const std::vector<exec::VectorFunctionArg>& inputArgs) {
  VELOX_CHECK_EQ(inputArgs.size(), 2);
  auto left = inputArgs[0].constantValue.get();
  auto right = inputArgs[1].constantValue.get();
  using T = typename TypeTraits<kind>::NativeType;

  // No constant values.
  if (left == nullptr && right == nullptr) {
    return std::make_shared<ArraysOverlapFunction<T>>();
  }

  auto isLeftConstant = (left != nullptr);
  auto baseVector = isLeftConstant ? left : right;
  auto constantSet = validateConstantVectorAndGenerateSet<T>(baseVector);
  return std::make_shared<ArraysOverlapFunction<T>>(
      std::move(constantSet), isLeftConstant);
}
std::shared_ptr<exec::VectorFunction> createArraysOverlapFunction(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs) {
validateMatchingArrayTypes(inputArgs, name, 2);
auto elementType = inputArgs.front().type->childAt(0);
return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
createTypedArraysOverlap, elementType->kind(), inputArgs);
}
} // namespace
// Declarations of the three vector functions implemented in this file. Each
// invocation passes a udf_* tag, the signature list produced by signatures()
// for the given return type, and the factory that creates the function object.
// NOTE(review): VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION is defined outside this
// file; presumably the tag is what registration code refers to later —
// confirm against the macro's definition.

// arrays_overlap: (array(T), array(T)) -> boolean.
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_arrays_overlap,
signatures("boolean"),
createArraysOverlapFunction);
// array_intersect: (array(T), array(T)) -> array(T).
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_array_intersect,
signatures("array(T)"),
createArrayIntersect);
// array_except: (array(T), array(T)) -> array(T).
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_array_except,
signatures("array(T)"),
createArrayExcept);
} // namespace facebook::velox::functions