query/filter.cu (232 lines of code) (raw):

// Copyright (c) 2017-2018 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <exception>
#include <initializer_list>
#include <iostream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>
#include "query/transform.hpp"
#include "query/binder.hpp"
#include "query/utils.hpp"

namespace ares {

// FilterContext is doing the actual filter after binding one or two
// input iterators.
//
// It stores the predicate output buffer, the number of index entries to
// process, the foreign table record id vectors, the functor selector and the
// CUDA stream, and exposes run() overloads that pick the index zip iterator
// matching numForeignTables before delegating to executeRemoveIf().
template<typename FunctorType>
class FilterContext {
 public:
  // predicateVector:             buffer that receives one predicate result
  //                              per processed element.
  // indexVectorLength:           number of elements of the index vector that
  //                              are processed.
  // foreignTableRecordIDVectors: per foreign table record id vectors that are
  //                              zipped together with the index vector.
  // numForeignTables:            number of foreign tables; run() accepts
  //                              values 0 through 8 only.
  // functorType:                 selector handed to the predicate functor.
  // cudaStream:                  opaque handle reinterpreted as cudaStream_t.
  FilterContext(uint8_t *predicateVector,
                int indexVectorLength,
                RecordID **foreignTableRecordIDVectors,
                int numForeignTables,
                FunctorType functorType,
                void *cudaStream)
      : predicateVector(predicateVector),
        indexVectorLength(indexVectorLength),
        foreignTableRecordIDVectors(foreignTableRecordIDVectors),
        numForeignTables(numForeignTables),
        functorType(functorType),
        cudaStream(reinterpret_cast<cudaStream_t>(cudaStream)) {}

  // Returns the stream this context issues its thrust calls on.
  cudaStream_t getStream() const {
    return cudaStream;
  }

  // Unary filter entry point. Builds the index zip iterator for
  // numForeignTables and returns the result of the unary executeRemoveIf().
  // Throws std::invalid_argument when numForeignTables is outside [0, 8].
  template<typename InputIterator>
  int run(uint32_t *indexVector, InputIterator inputIterator) {
    switch (numForeignTables) {
#define EXECUTE_UNARY_REMOVE_IF(NumTotalForeignTables)                     \
      case NumTotalForeignTables: {                                        \
        IndexZipIteratorMaker<NumTotalForeignTables> maker;                \
        return executeRemoveIf(inputIterator,                              \
                               maker.make(indexVector,                     \
                                          foreignTableRecordIDVectors));   \
      }

      EXECUTE_UNARY_REMOVE_IF(0)
      EXECUTE_UNARY_REMOVE_IF(1)
      EXECUTE_UNARY_REMOVE_IF(2)
      EXECUTE_UNARY_REMOVE_IF(3)
      EXECUTE_UNARY_REMOVE_IF(4)
      EXECUTE_UNARY_REMOVE_IF(5)
      EXECUTE_UNARY_REMOVE_IF(6)
      EXECUTE_UNARY_REMOVE_IF(7)
      EXECUTE_UNARY_REMOVE_IF(8)
// The helper macro is only meaningful inside this switch; do not let it leak
// into the rest of the translation unit.
#undef EXECUTE_UNARY_REMOVE_IF
      default:
        throw std::invalid_argument(
            "only support up to 8 foreign tables");
    }
  }

  // Binary filter entry point, selected when
  // supported_binary_combination<LHSIterator, RHSIterator>::value is true.
  // Builds the index zip iterator for numForeignTables and returns the result
  // of the binary executeRemoveIf().
  // Throws std::invalid_argument when numForeignTables is outside [0, 8].
  template<typename LHSIterator, typename RHSIterator>
  typename std::enable_if<
      supported_binary_combination<LHSIterator, RHSIterator>::value,
      int>::type
  run(uint32_t *indexVector, LHSIterator lhsIter, RHSIterator rhsIter) {
    switch (numForeignTables) {
#define EXECUTE_BINARY_REMOVE_IF(NumTotalForeignTables)                    \
      case NumTotalForeignTables: {                                        \
        IndexZipIteratorMaker<NumTotalForeignTables> maker;                \
        return executeRemoveIf(lhsIter, rhsIter,                           \
                               maker.make(indexVector,                     \
                                          foreignTableRecordIDVectors));   \
      }

      EXECUTE_BINARY_REMOVE_IF(0)
      EXECUTE_BINARY_REMOVE_IF(1)
      EXECUTE_BINARY_REMOVE_IF(2)
      EXECUTE_BINARY_REMOVE_IF(3)
      EXECUTE_BINARY_REMOVE_IF(4)
      EXECUTE_BINARY_REMOVE_IF(5)
      EXECUTE_BINARY_REMOVE_IF(6)
      EXECUTE_BINARY_REMOVE_IF(7)
      EXECUTE_BINARY_REMOVE_IF(8)
#undef EXECUTE_BINARY_REMOVE_IF
      default:
        throw std::invalid_argument(
            "only support up to 8 foreign tables");
    }
  }

  // Fallback for iterator pairs that supported_binary_combination rejects:
  // always throws std::invalid_argument naming the offending instantiation.
  template<typename LHSIterator, typename RHSIterator>
  typename std::enable_if<
      !supported_binary_combination<LHSIterator, RHSIterator>::value,
      int>::type
  run(uint32_t *indexVector, LHSIterator lhsIter, RHSIterator rhsIter) {
    throw std::invalid_argument(
        std::string("Unsupported data type combination ")
            + __PRETTY_FUNCTION__ + ", " + __FILE__ + ": "
            + std::to_string(__LINE__) + " in filter context");
  }

 private:
  uint8_t *predicateVector;
  int indexVectorLength;
  RecordID **foreignTableRecordIDVectors;
  int numForeignTables;
  FunctorType functorType;
  cudaStream_t cudaStream;

  // Binary worker; defined after the C entry points in this file.
  template<typename LHSIterator, typename RHSIterator,
      typename IndexZipIterator>
  int executeRemoveIf(LHSIterator lhsIter, RHSIterator rhsIter,
                      IndexZipIterator indexZipIterator);

  // Unary worker; defined after the C entry points in this file.
  template<typename InputIterator, typename IndexZipIterator>
  int executeRemoveIf(InputIterator inputIter,
                      IndexZipIterator indexZipIterator);
};

}  // namespace ares

namespace {

// Makes `device` the current CUDA device for the calling host thread.
//
// The call is only compiled when RUN_ON_DEVICE is defined (as the original
// unguarded-by-check call was). Unlike the original, the return code is
// inspected: continuing after a failed device switch would build and launch
// the filter against whatever device happened to be current.
// Throws std::runtime_error carrying the CUDA error string on failure.
inline void selectFilterDevice(int device) {
#ifdef RUN_ON_DEVICE
  const cudaError_t status = cudaSetDevice(device);
  if (status != cudaSuccess) {
    // Reset the runtime's last-error slot so this failure is not reported a
    // second time by a later, unrelated error check.
    cudaGetLastError();
    throw std::runtime_error(
        std::string("cudaSetDevice(") + std::to_string(device)
            + ") failed: " + cudaGetErrorString(status));
  }
#else
  (void) device;  // cudaSetDevice is compiled out without RUN_ON_DEVICE.
#endif
}

}  // namespace

// UnaryFilter applies a unary predicate filter to a single input vector.
//
// A FilterContext<UnaryFunctorType> is bound to `input` through
// InputVectorBinder, and the integer returned by binder.bind() is stored,
// cast to void*, in the `res` field of the returned handle.
// NOTE(review): bind() presumably forwards to FilterContext::run, whose result
// is the length of the index range kept by the filter - confirm in
// query/binder.hpp.
//
// Any std::exception raised on the way (including a failed device switch) is
// logged to std::cerr and returned as a strdup()'ed message in `pStrErr`; the
// caller owns that string.
CGoCallResHandle UnaryFilter(InputVector input,
                             uint32_t *indexVector,
                             uint8_t *predicateVector,
                             int indexVectorLength,
                             RecordID **foreignTableRecordIDVectors,
                             int numForeignTables,
                             uint32_t *baseCounts,
                             uint32_t startCount,
                             UnaryFunctorType functorType,
                             void *cudaStream,
                             int device) {
  CGoCallResHandle resHandle = {nullptr, nullptr};
  try {
    selectFilterDevice(device);
    ares::FilterContext<UnaryFunctorType> ctx(predicateVector,
                                              indexVectorLength,
                                              foreignTableRecordIDVectors,
                                              numForeignTables,
                                              functorType,
                                              cudaStream);
    std::vector<InputVector> inputVectors = {input};
    ares::InputVectorBinder<ares::FilterContext<UnaryFunctorType>, 1>
        binder(ctx, inputVectors, indexVector, baseCounts, startCount);
    resHandle.res = reinterpret_cast<void *>(binder.bind());
    CheckCUDAError("UnaryFilter");
  } catch (const std::exception &e) {
    std::cerr << "Exception happend when doing UnaryFilter:" << e.what()
              << std::endl;
    resHandle.pStrErr = strdup(e.what());
  }
  return resHandle;
}

// BinaryFilter applies a binary predicate filter to a pair of input vectors.
//
// Mirrors UnaryFilter: a FilterContext<BinaryFunctorType> is bound to
// {lhs, rhs} through InputVectorBinder, and the integer returned by
// binder.bind() is stored, cast to void*, in the `res` field of the returned
// handle.
//
// Any std::exception raised on the way (including a failed device switch) is
// logged to std::cerr and returned as a strdup()'ed message in `pStrErr`; the
// caller owns that string.
CGoCallResHandle BinaryFilter(InputVector lhs,
                              InputVector rhs,
                              uint32_t *indexVector,
                              uint8_t *predicateVector,
                              int indexVectorLength,
                              RecordID **foreignTableRecordIDVectors,
                              int numForeignTables,
                              uint32_t *baseCounts,
                              uint32_t startCount,
                              BinaryFunctorType functorType,
                              void *cudaStream,
                              int device) {
  CGoCallResHandle resHandle = {nullptr, nullptr};
  try {
    selectFilterDevice(device);
    ares::FilterContext<BinaryFunctorType> ctx(predicateVector,
                                               indexVectorLength,
                                               foreignTableRecordIDVectors,
                                               numForeignTables,
                                               functorType,
                                               cudaStream);
    std::vector<InputVector> inputVectors = {lhs, rhs};
    ares::InputVectorBinder<ares::FilterContext<BinaryFunctorType>, 2>
        binder(ctx, inputVectors, indexVector, baseCounts, startCount);
    resHandle.res = reinterpret_cast<void *>(binder.bind());
    CheckCUDAError("BinaryFilter");
  } catch (const std::exception &e) {
    std::cerr << "Exception happend when doing BinaryFilter:" << e.what()
              << std::endl;
    resHandle.pStrErr = strdup(e.what());
  }
  return resHandle;
}

namespace ares {

// Filter template function for unary transform filter.
// Unary worker behind FilterContext::run.
//
// Step 1 evaluates the unary predicate functor over the first
// indexVectorLength elements of inputIter and stores one result per element
// in predicateVector.
// Step 2 compacts [indexZipIterator, indexZipIterator + indexVectorLength) in
// place with thrust::remove_if; which entries are dropped is decided by a
// RemoveFilter constructed over predicateVector.
// Both steps are issued with GET_EXECUTION_POLICY(cudaStream).
//
// Returns the distance from indexZipIterator to the end returned by
// thrust::remove_if, i.e. the number of index entries that remain.
template<typename FunctorType>
template<typename InputIterator, typename IndexZipIterator>
int FilterContext<FunctorType>::executeRemoveIf(
    InputIterator inputIter, IndexZipIterator indexZipIterator) {
  using ValueType = typename InputIterator::value_type::head_type;
  using IndexTuple = typename IndexZipIterator::value_type;

  UnaryPredicateFunctor<bool, ValueType> predicate(functorType);
  RemoveFilter<IndexTuple, uint8_t> dropFiltered(predicateVector);

  // Materialise the predicate result of every element.
  const auto inputEnd = inputIter + indexVectorLength;
  thrust::transform(GET_EXECUTION_POLICY(cudaStream),
                    inputIter, inputEnd,
                    predicateVector,
                    predicate);

  // Compact the index entries in place according to those results.
  const auto indexEnd = indexZipIterator + indexVectorLength;
  const auto keptEnd = thrust::remove_if(GET_EXECUTION_POLICY(cudaStream),
                                         indexZipIterator, indexEnd,
                                         dropFiltered);
  return keptEnd - indexZipIterator;
}

// Binary worker behind FilterContext::run.
//
// Same two steps as the unary worker, except that the predicate functor is
// evaluated pairwise over lhsIter and rhsIter. The functor's two operand
// types are derived through input_iterator_value_type from the head types of
// both iterators' value types (once in LHS/RHS order, once in RHS/LHS order).
//
// Returns the distance from indexZipIterator to the end returned by
// thrust::remove_if, i.e. the number of index entries that remain.
template<typename FunctorType>
template<typename LHSIterator, typename RHSIterator, typename IndexZipIterator>
int FilterContext<FunctorType>::executeRemoveIf(
    LHSIterator lhsIter,
    RHSIterator rhsIter,
    IndexZipIterator indexZipIterator) {
  using LHSHead = typename LHSIterator::value_type::head_type;
  using RHSHead = typename RHSIterator::value_type::head_type;
  using LHSOperand = typename input_iterator_value_type<LHSHead, RHSHead>::type;
  using RHSOperand = typename input_iterator_value_type<RHSHead, LHSHead>::type;
  using IndexTuple = typename IndexZipIterator::value_type;

  BinaryPredicateFunctor<bool, LHSOperand, RHSOperand> predicate(functorType);
  RemoveFilter<IndexTuple, uint8_t> dropFiltered(predicateVector);

  // Materialise the predicate result of every (lhs, rhs) pair.
  const auto lhsEnd = lhsIter + indexVectorLength;
  thrust::transform(GET_EXECUTION_POLICY(cudaStream),
                    lhsIter, lhsEnd,
                    rhsIter,
                    predicateVector,
                    predicate);

  // Compact the index entries in place according to those results.
  const auto indexEnd = indexZipIterator + indexVectorLength;
  const auto keptEnd = thrust::remove_if(GET_EXECUTION_POLICY(cudaStream),
                                         indexZipIterator, indexEnd,
                                         dropFiltered);
  return keptEnd - indexZipIterator;
}

}  // namespace ares