query/hash_lookup.cu (143 lines of code) (raw):
// Copyright (c) 2017-2018 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cstring>
#include <exception>
#include <stdexcept>
#include <vector>
#include "query/transform.hpp"
#include "query/binder.hpp"
namespace ares {
// Execution context for a hash lookup. It bundles the hash index to probe,
// the output record ID vector, the number of input elements to process and
// the stream handed to the execution policy in run().
class HashLookupContext {
 public:
  // Stores all arguments as given. cudaStream arrives as an opaque pointer
  // and is reinterpreted as a cudaStream_t.
  HashLookupContext(int indexVectorLength, CuckooHashIndex hashIndex,
                    RecordID* recordIDVector, void *cudaStream)
      : indexVectorLength(indexVectorLength),
        hashIndex(hashIndex),
        recordIDVector(recordIDVector),
        cudaStream(reinterpret_cast<cudaStream_t>(cudaStream)) {}

  // Returns the stream this context was created with.
  cudaStream_t getStream() const { return cudaStream; }

  // Transforms indexVectorLength elements starting at inputIterator into
  // recordIDVector; the definition lives further down in this file.
  template <typename InputIterator>
  int run(uint32_t *indexVector, InputIterator inputIterator);

 private:
  // Number of input elements consumed by run().
  int indexVectorLength;
  // Hash index whose buckets/seeds parameterize the lookup functor.
  CuckooHashIndex hashIndex;
  // Destination of the lookup results.
  RecordID* recordIDVector;
  // Stream passed to GET_EXECUTION_POLICY in run().
  cudaStream_t cudaStream;
};
// InputVectorBinder specialization for HashLookupContext. It forwards all
// constructor arguments to InputVectorBinderBase<HashLookupContext, 1, 1>
// and declares its own bind(), which is defined later in this file.
template <>
class InputVectorBinder<HashLookupContext, 1>
    : public InputVectorBinderBase<HashLookupContext, 1, 1> {
  using super_t = InputVectorBinderBase<HashLookupContext, 1, 1>;

 public:
  // Passes every argument straight through to the base binder.
  explicit InputVectorBinder(HashLookupContext context,
                             std::vector<InputVector> inputVectors,
                             uint32_t *indexVector, uint32_t *baseCounts,
                             uint32_t startCount)
      : super_t(context, inputVectors, indexVector, baseCounts, startCount) {}

  // Binds the input vector to an iterator, appending it to the iterators
  // that have already been bound.
  template <typename... InputIterators>
  int bind(InputIterators... boundInputIterators);
};
} // namespace ares
// Entry point for the hash lookup. Builds a HashLookupContext around the
// hash index and the output vector, binds the single input vector and runs
// the lookup.
//
// Returns a CGoCallResHandle whose:
//   - res holds the int returned by the binder's bind(), cast to a pointer
//     value;
//   - pStrErr is nullptr on success, or a strdup'ed copy of the exception
//     message on failure (presumably released by the caller - TODO confirm).
//
// No exception derived from std::exception escapes this function; each one
// is logged to std::cerr and reported through pStrErr.
CGoCallResHandle HashLookup(InputVector input, RecordID *output,
                            uint32_t *indexVector, int indexVectorLength,
                            uint32_t *baseCounts, uint32_t startCount,
                            CuckooHashIndex hashIndex, void *cudaStream,
                            int device) {
  CGoCallResHandle resHandle = {nullptr, nullptr};
  try {
#ifdef RUN_ON_DEVICE
    // Fail fast when the device cannot be selected; otherwise the lookup
    // would be issued against whatever device happens to be current.
    const cudaError_t setDeviceErr = cudaSetDevice(device);
    if (setDeviceErr != cudaSuccess) {
      // Reset the thread's last-error state so that a later error check does
      // not report this already-handled failure a second time.
      cudaGetLastError();
      throw std::runtime_error(cudaGetErrorString(setDeviceErr));
    }
#endif
    ares::HashLookupContext ctx(indexVectorLength,
                                hashIndex, output, cudaStream);
    std::vector<InputVector> inputVectors = {input};
    ares::InputVectorBinder<ares::HashLookupContext, 1>
        binder(ctx, inputVectors, indexVector, baseCounts, startCount);
    resHandle.res =
        reinterpret_cast<void *>(binder.bind());
    CheckCUDAError("HashLookup");
  } catch (const std::exception &e) {
    std::cerr << "Exception happened when doing HashLookup:" << e.what()
              << std::endl;
    resHandle.pStrErr = strdup(e.what());
  }
  return resHandle;
}
namespace ares {
// bind() for the hash lookup binder, adding support for UUID columns.
//
// A vector party input whose data type is UUID is bound here, either as a
// constant iterator (no base pointer) or as a column iterator over the
// stored values. Every other input is delegated to the base binder.
// Returns whatever the binder that finally handles the input returns.
template <typename... InputIterators>
int InputVectorBinder<HashLookupContext, 1>::bind(
    InputIterators... boundInputIterators) {
  InputVector input = super_t::inputVectors[0];
  uint32_t *idxVector = super_t::indexVector;
  uint32_t *counts = super_t::baseCounts;
  uint32_t firstCount = super_t::startCount;

  // Inputs that are not vector parties take the generic path.
  if (input.Type != VectorPartyInput) {
    return super_t::bind(boundInputIterators...);
  }

  InputVectorBinderBase<HashLookupContext, 1, 0>
      nextBinder(context, inputVectors, idxVector, counts, firstCount);
  VectorPartySlice slice = input.Vector.VP;

  // Only UUID columns need the special handling below.
  if (slice.DataType != UUID) {
    return super_t::bind(boundInputIterators...);
  }

  uint8_t *valuesBase = slice.BasePtr;

  // Treat mode 0 as constant vector: without a base pointer every element
  // is the (default value, has-default flag) pair.
  if (valuesBase == nullptr) {
    bool defaultPresent = slice.DefaultValue.HasDefault;
    DefaultValue fallback = slice.DefaultValue;
    return nextBinder.bind(
        boundInputIterators...,
        thrust::make_constant_iterator(
            thrust::make_tuple<UUIDT, bool>(fallback.Value.UUIDVal,
                                            defaultPresent)));
  }

  // Otherwise iterate over the stored column.
  uint32_t nullsOff = slice.NullsOffset;
  uint32_t valuesOff = slice.ValuesOffset;
  uint8_t firstIndex = slice.StartingIndex;
  uint8_t stride = getStepInBytes(slice.DataType);
  uint32_t numValues = slice.Length;
  return nextBinder.bind(
      boundInputIterators...,
      make_column_iterator<UUIDT>(idxVector, counts, firstCount, valuesBase,
                                  nullsOff, valuesOff, numValues, stride,
                                  firstIndex));
}
// Applies a HashLookupFunctor, configured from hashIndex, to the
// indexVectorLength elements starting at inputIter and writes the results
// to recordIDVector using the execution policy for cudaStream.
// The indexVector parameter is not read here.
// Returns the number of output elements written.
template <typename InputIterator>
int HashLookupContext::run(uint32_t *indexVector, InputIterator inputIter) {
  using InputValueType = typename InputIterator::value_type::head_type;

  HashLookupFunctor<InputValueType> lookup(
      hashIndex.buckets, hashIndex.seeds, hashIndex.keyBytes,
      hashIndex.numHashes, hashIndex.numBuckets);

  InputIterator inputEnd = inputIter + indexVectorLength;
  RecordID *outputEnd =
      thrust::transform(GET_EXECUTION_POLICY(cudaStream), inputIter, inputEnd,
                        recordIDVector, lookup);
  return outputEnd - recordIDVector;
}
} // namespace ares