velox/exec/MergeJoin.cpp (510 lines of code) (raw):

/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "velox/exec/MergeJoin.h"

#include <algorithm>

#include "velox/exec/OperatorUtils.h"
#include "velox/exec/Task.h"
#include "velox/expression/ControlExpr.h"

namespace facebook::velox::exec {

// Resolves the join key channels on both inputs, builds identity projections
// for every input column whose name also appears in the output type, and sets
// up the optional join filter. A LeftJoinTracker is created only for left
// joins that have a filter.
MergeJoin::MergeJoin(
    int32_t operatorId,
    DriverCtx* driverCtx,
    const std::shared_ptr<const core::MergeJoinNode>& joinNode)
    : Operator(
          driverCtx,
          joinNode->outputType(),
          operatorId,
          joinNode->id(),
          "MergeJoin"),
      outputBatchSize_{driverCtx->queryConfig().preferredOutputBatchSize()},
      joinType_{joinNode->joinType()},
      numKeys_{joinNode->leftKeys().size()} {
  VELOX_USER_CHECK(
      joinNode->isInnerJoin() || joinNode->isLeftJoin(),
      "Merge join supports only inner and left joins. Other join types are not supported yet.");

  leftKeys_.reserve(numKeys_);
  rightKeys_.reserve(numKeys_);

  auto leftType = joinNode->sources()[0]->outputType();
  for (auto& key : joinNode->leftKeys()) {
    leftKeys_.push_back(leftType->getChildIdx(key->name()));
  }

  auto rightType = joinNode->sources()[1]->outputType();
  for (auto& key : joinNode->rightKeys()) {
    rightKeys_.push_back(rightType->getChildIdx(key->name()));
  }

  // Input columns are matched to output columns by name.
  for (ChannelIndex i = 0; i < leftType->size(); ++i) {
    const auto& name = leftType->nameOf(i);
    auto outIndex = outputType_->getChildIdxIfExists(name);
    if (outIndex.has_value()) {
      leftProjections_.emplace_back(i, outIndex.value());
    }
  }

  for (ChannelIndex i = 0; i < rightType->size(); ++i) {
    const auto& name = rightType->nameOf(i);
    auto outIndex = outputType_->getChildIdxIfExists(name);
    if (outIndex.has_value()) {
      rightProjections_.emplace_back(i, outIndex.value());
    }
  }

  if (joinNode->filter()) {
    initializeFilter(joinNode->filter(), leftType, rightType);

    if (joinNode->isLeftJoin()) {
      leftJoinTracker_ = LeftJoinTracker(outputBatchSize_, pool());
    }
  }
}

// Compiles 'filter' into filter_ and derives filterInputType_: one column per
// distinct field referenced by the filter, resolved first against the left
// input and then against the right input. The channel mappings are recorded
// in filterLeftInputs_ / filterRightInputs_ so that addOutputRow() can copy
// the referenced values into filterInput_. Fails if a field is found in
// neither input.
void MergeJoin::initializeFilter(
    const std::shared_ptr<const core::ITypedExpr>& filter,
    const RowTypePtr& leftType,
    const RowTypePtr& rightType) {
  std::vector<std::shared_ptr<const core::ITypedExpr>> filters = {filter};
  filter_ =
      std::make_unique<ExprSet>(std::move(filters), operatorCtx_->execCtx());

  ChannelIndex filterChannel = 0;
  std::vector<std::string> names;
  std::vector<TypePtr> types;
  auto numFields = filter_->expr(0)->distinctFields().size();
  names.reserve(numFields);
  types.reserve(numFields);

  for (const auto& field : filter_->expr(0)->distinctFields()) {
    const auto& name = field->field();

    auto channel = leftType->getChildIdxIfExists(name);
    if (channel.has_value()) {
      auto channelValue = channel.value();
      filterLeftInputs_.emplace_back(channelValue, filterChannel++);
      names.emplace_back(leftType->nameOf(channelValue));
      types.emplace_back(leftType->childAt(channelValue));
      continue;
    }

    channel = rightType->getChildIdxIfExists(name);
    if (channel.has_value()) {
      auto channelValue = channel.value();
      filterRightInputs_.emplace_back(channelValue, filterChannel++);
      names.emplace_back(rightType->nameOf(channelValue));
      types.emplace_back(rightType->childAt(channelValue));
      continue;
    }

    VELOX_FAIL(
        "Merge join filter field not found in either left or right input: {}",
        field->toString());
  }

  filterInputType_ = ROW(std::move(names), std::move(types));
}

// The operator blocks only while waiting on the future handed out by the
// right-side source in getOutput().
BlockingReason MergeJoin::isBlocked(ContinueFuture* future) {
  if (future_.valid()) {
    *future = std::move(future_);
    return BlockingReason::kWaitForExchange;
  }

  return BlockingReason::kNotBlocked;
}

// Accepts a new left batch only once the previous one has been fully consumed.
bool MergeJoin::needsInput() const {
  return input_ == nullptr;
}

// Installs a new left-side batch; scanning restarts at its first row.
void MergeJoin::addInput(RowVectorPtr input) {
  input_ = std::move(input);
  index_ = 0;

  if (leftJoinTracker_) {
    leftJoinTracker_->resetLastVector();
  }
}

// Compares the key columns 'keys' of batch[index] with the key columns
// 'otherKeys' of otherBatch[otherIndex], one key at a time. Returns the
// result of the first non-equal key comparison, or 0 if all keys are equal.
// static
int32_t MergeJoin::compare(
    const std::vector<ChannelIndex>& keys,
    const RowVectorPtr& batch,
    vector_size_t index,
    const std::vector<ChannelIndex>& otherKeys,
    const RowVectorPtr& otherBatch,
    vector_size_t otherIndex) {
  for (size_t i = 0; i < keys.size(); ++i) {
    auto result = batch->childAt(keys[i])->compare(
        otherBatch->childAt(otherKeys[i]).get(), index, otherIndex);
    if (result != 0) {
      return result;
    }
  }

  return 0;
}

// Extends an incomplete 'match' into the new batch 'input' by counting the
// leading rows of 'input' whose keys equal the last row of the match's last
// batch. Returns true once the end of the match is known (match.complete is
// then set); returns false if every row of 'input' continues the match, in
// which case 'input' is appended to match.inputs and more input is needed.
bool MergeJoin::findEndOfMatch(
    Match& match,
    const RowVectorPtr& input,
    const std::vector<ChannelIndex>& keys) {
  if (match.complete) {
    return true;
  }

  auto prevInput = match.inputs.back();
  auto prevIndex = prevInput->size() - 1;

  auto numInput = input->size();

  vector_size_t endIndex = 0;
  while (endIndex < numInput &&
         compare(keys, input, endIndex, keys, prevInput, prevIndex) == 0) {
    ++endIndex;
  }

  if (endIndex == numInput) {
    // Inputs are kept past getting a new batch of inputs. LazyVectors
    // must be loaded before advancing to the next batch.
    loadColumns(input, *operatorCtx_->execCtx());
    match.inputs.push_back(input);
    match.endIndex = endIndex;
    return false;
  }

  if (endIndex > 0) {
    // Match ends here, no need to pre-load lazies.
    match.inputs.push_back(input);
    match.endIndex = endIndex;
  }

  match.complete = true;
  return true;
}

namespace {
// Copies row 'sourceIndex' of 'source' into row 'targetIndex' of 'target' for
// every (inputChannel -> outputChannel) pair in 'projections'.
void copyRow(
    const RowVectorPtr& source,
    vector_size_t sourceIndex,
    const RowVectorPtr& target,
    vector_size_t targetIndex,
    const std::vector<IdentityProjection>& projections) {
  for (const auto& projection : projections) {
    // Bind by reference: this runs once per column per output row, and
    // copying the shared_ptr would cost an atomic inc/dec each time.
    const auto& sourceChild = source->childAt(projection.inputChannel);
    const auto& targetChild = target->childAt(projection.outputChannel);
    targetChild->copy(sourceChild.get(), targetIndex, sourceIndex, 1);
  }
}
} // namespace

// Appends the current left row (input_[index_]) to output_ with all
// right-side output columns set to null.
void MergeJoin::addOutputRowForLeftJoin() {
  copyRow(input_, index_, output_, outputSize_, leftProjections_);

  for (const auto& projection : rightProjections_) {
    const auto& target = output_->childAt(projection.outputChannel);
    target->setNull(outputSize_, true);
  }

  if (leftJoinTracker_) {
    // Record left-side row with no match on the right side.
    leftJoinTracker_->addMiss(outputSize_);
  }

  ++outputSize_;
}

// Appends the joined row (left[leftIndex], right[rightIndex]) to output_.
// With a filter, the columns the filter reads are also copied into
// filterInput_ at the same row index so that the filter can be evaluated
// over the whole output batch later.
void MergeJoin::addOutputRow(
    const RowVectorPtr& left,
    vector_size_t leftIndex,
    const RowVectorPtr& right,
    vector_size_t rightIndex) {
  copyRow(left, leftIndex, output_, outputSize_, leftProjections_);
  copyRow(right, rightIndex, output_, outputSize_, rightProjections_);

  if (filter_) {
    // TODO Re-use output_ columns when possible.
    copyRow(left, leftIndex, filterInput_, outputSize_, filterLeftInputs_);
    copyRow(right, rightIndex, filterInput_, outputSize_, filterRightInputs_);

    if (leftJoinTracker_) {
      // Record left-side row with a match on the right-side.
      leftJoinTracker_->addMatch(left, leftIndex, outputSize_);
    }
  }

  ++outputSize_;
}

// Allocates output_ (outputBatchSize_ rows, outputSize_ reset to 0) if there
// is no output in progress, and allocates filterInput_ once when a filter is
// present.
void MergeJoin::prepareOutput() {
  if (output_ == nullptr) {
    std::vector<VectorPtr> localColumns(outputType_->size());
    for (ChannelIndex i = 0; i < outputType_->size(); ++i) {
      localColumns[i] = BaseVector::create(
          outputType_->childAt(i), outputBatchSize_, operatorCtx_->pool());
    }

    output_ = std::make_shared<RowVector>(
        operatorCtx_->pool(),
        outputType_,
        nullptr,
        outputBatchSize_,
        std::move(localColumns));
    outputSize_ = 0;

    if (filterInput_ != nullptr) {
      // When filterInput_ contains array or map columns, their child vectors
      // (elements, keys and values) keep growing after each call to
      // 'copyRow'. Call BaseVector::resize(0) on these child vectors to avoid
      // that.
      // TODO Refactor this logic into a method on BaseVector.
      for (auto& child : filterInput_->children()) {
        if (child->typeKind() == TypeKind::ARRAY) {
          child->as<ArrayVector>()->elements()->resize(0);
        } else if (child->typeKind() == TypeKind::MAP) {
          auto* mapChild = child->as<MapVector>();
          mapChild->mapKeys()->resize(0);
          mapChild->mapValues()->resize(0);
        }
      }
    }
  }

  if (filter_ != nullptr && filterInput_ == nullptr) {
    std::vector<VectorPtr> inputs(filterInputType_->size());
    for (ChannelIndex i = 0; i < filterInputType_->size(); ++i) {
      inputs[i] = BaseVector::create(
          filterInputType_->childAt(i), outputBatchSize_, operatorCtx_->pool());
    }

    filterInput_ = std::make_shared<RowVector>(
        operatorCtx_->pool(),
        filterInputType_,
        nullptr,
        outputBatchSize_,
        std::move(inputs));
  }
}

// Emits the cross product of the complete leftMatch_ x rightMatch_ into
// output_, resuming from the saved cursors if a previous call ran out of
// space. If output_ fills up mid-match, cursors are saved and true is
// returned with the matches kept; otherwise both matches are reset and the
// return value tells whether output_ happens to be exactly full.
bool MergeJoin::addToOutput() {
  prepareOutput();

  size_t firstLeftBatch;
  vector_size_t leftStartIndex;
  if (leftMatch_->cursor) {
    firstLeftBatch = leftMatch_->cursor->batchIndex;
    leftStartIndex = leftMatch_->cursor->index;
  } else {
    firstLeftBatch = 0;
    leftStartIndex = leftMatch_->startIndex;
  }

  size_t numLefts = leftMatch_->inputs.size();
  for (size_t l = firstLeftBatch; l < numLefts; ++l) {
    const auto& left = leftMatch_->inputs[l];
    auto leftStart = l == firstLeftBatch ? leftStartIndex : 0;
    auto leftEnd = l == numLefts - 1 ? leftMatch_->endIndex : left->size();

    for (auto i = leftStart; i < leftEnd; ++i) {
      // Only the very first left row resumes from the right cursor; every
      // following left row pairs with the whole right match.
      const bool resumeRight =
          l == firstLeftBatch && i == leftStart && rightMatch_->cursor;
      auto firstRightBatch = resumeRight ? rightMatch_->cursor->batchIndex : 0;
      auto rightStartIndex =
          resumeRight ? rightMatch_->cursor->index : rightMatch_->startIndex;

      auto numRights = rightMatch_->inputs.size();
      for (size_t r = firstRightBatch; r < numRights; ++r) {
        const auto& right = rightMatch_->inputs[r];
        auto rightStart = r == firstRightBatch ? rightStartIndex : 0;
        auto rightEnd =
            r == numRights - 1 ? rightMatch_->endIndex : right->size();

        for (auto j = rightStart; j < rightEnd; ++j) {
          if (outputSize_ == outputBatchSize_) {
            // Out of space: remember where to resume.
            leftMatch_->setCursor(l, i);
            rightMatch_->setCursor(r, j);
            return true;
          }
          addOutputRow(left, i, right, j);
        }
      }
    }
  }

  leftMatch_.reset();
  rightMatch_.reset();

  return outputSize_ == outputBatchSize_;
}

RowVectorPtr MergeJoin::getOutput() {
  // Make sure to have is-blocked or needs-input as true if returning null
  // output. Otherwise, Driver assumes the operator is finished.

  // Use Operator::noMoreInput() as a no-more-input-on-the-left indicator and a
  // noMoreRightInput_ flag as no-more-input-on-the-right indicator.

  // TODO Finish early if ran out of data on either side of the join.

  for (;;) {
    auto output = doGetOutput();
    if (output != nullptr) {
      if (!filter_) {
        return output;
      }

      output = applyFilter(output);
      if (output != nullptr) {
        return output;
      }

      // No rows survived the filter. Get more rows.
      continue;
    }

    // Check if we need to get more data from the right side.
    if (!noMoreRightInput_ && !future_.valid() && !rightInput_) {
      if (!rightSource_) {
        rightSource_ = operatorCtx_->task()->getMergeJoinSource(
            operatorCtx_->driverCtx()->splitGroupId, planNodeId());
      }
      auto blockingReason = rightSource_->next(&future_, &rightInput_);
      if (blockingReason != BlockingReason::kNotBlocked) {
        return nullptr;
      }

      if (rightInput_) {
        rightIndex_ = 0;
      } else {
        noMoreRightInput_ = true;
      }
      continue;
    }

    return nullptr;
  }
}

// Core merge loop. Returns a full (or, at end of input, final partial)
// output batch, or nullptr when more input is needed on either side. Matches
// that may span batches are tracked in leftMatch_ / rightMatch_.
RowVectorPtr MergeJoin::doGetOutput() {
  // Check if we ran out of space in the output vector in the middle of the
  // match.
  if (leftMatch_ && leftMatch_->cursor) {
    VELOX_CHECK(rightMatch_ && rightMatch_->cursor);

    // Not all rows from the last match fit in the output. Continue producing
    // results from the current match.
    if (addToOutput()) {
      return std::move(output_);
    }
  }

  // There is no output-in-progress match, but there could be incomplete
  // match.
  if (leftMatch_) {
    VELOX_CHECK(rightMatch_);

    if (input_) {
      // Look for continuation of a match on the left and/or right sides.
      if (!findEndOfMatch(leftMatch_.value(), input_, leftKeys_)) {
        // Continue looking for the end of the match.
        input_ = nullptr;
        return nullptr;
      }

      if (leftMatch_->inputs.back() == input_) {
        index_ = leftMatch_->endIndex;
      }
    } else if (noMoreInput_) {
      leftMatch_->complete = true;
    } else {
      // Need more input.
      return nullptr;
    }

    if (rightInput_) {
      if (!findEndOfMatch(rightMatch_.value(), rightInput_, rightKeys_)) {
        // Continue looking for the end of the match.
        rightInput_ = nullptr;
        return nullptr;
      }

      if (rightMatch_->inputs.back() == rightInput_) {
        rightIndex_ = rightMatch_->endIndex;
      }
    } else if (noMoreRightInput_) {
      rightMatch_->complete = true;
    } else {
      // Need more input.
      return nullptr;
    }
  }

  // There is no output-in-progress match, but there can be a complete match
  // ready for output.
  if (leftMatch_) {
    VELOX_CHECK(leftMatch_->complete);
    VELOX_CHECK(rightMatch_ && rightMatch_->complete);

    if (addToOutput()) {
      return std::move(output_);
    }
  }

  if (!input_ || !rightInput_) {
    if (isLeftJoin(joinType_)) {
      if (input_ && noMoreRightInput_) {
        // Right side is exhausted: every remaining left row is a miss.
        prepareOutput();
        while (true) {
          if (outputSize_ == outputBatchSize_) {
            return std::move(output_);
          }
          addOutputRowForLeftJoin();

          ++index_;
          if (index_ == input_->size()) {
            // Ran out of rows on the left side.
            input_ = nullptr;
            return nullptr;
          }
        }
      }

      if (noMoreInput_ && output_) {
        output_->resize(outputSize_);
        return std::move(output_);
      }
    } else {
      if (noMoreInput_ || noMoreRightInput_) {
        if (output_) {
          output_->resize(outputSize_);
          return std::move(output_);
        }
        input_ = nullptr;
      }
    }

    return nullptr;
  }

  // Look for a new match starting with index_ row on the left and rightIndex_
  // row on the right.
  auto compareResult = compare();

  for (;;) {
    // Catch up input_ with rightInput_.
    while (compareResult < 0) {
      if (isLeftJoin(joinType_)) {
        prepareOutput();
        if (outputSize_ == outputBatchSize_) {
          return std::move(output_);
        }
        addOutputRowForLeftJoin();
      }

      ++index_;
      if (index_ == input_->size()) {
        // Ran out of rows on the left side.
        input_ = nullptr;
        return nullptr;
      }
      compareResult = compare();
    }

    // Catch up rightInput_ with input_.
    while (compareResult > 0) {
      ++rightIndex_;
      if (rightIndex_ == rightInput_->size()) {
        // Ran out of rows on the right side.
        rightInput_ = nullptr;
        return nullptr;
      }
      compareResult = compare();
    }

    if (compareResult == 0) {
      // Found a match. Identify all rows on the left and right that have the
      // matching keys.
      vector_size_t endIndex = index_ + 1;
      while (endIndex < input_->size() && compareLeft(endIndex) == 0) {
        ++endIndex;
      }

      if (endIndex == input_->size()) {
        // Matches continue in subsequent input. Load all lazies.
        loadColumns(input_, *operatorCtx_->execCtx());
      }

      leftMatch_ = Match{
          {input_}, index_, endIndex, endIndex < input_->size(), std::nullopt};

      vector_size_t endRightIndex = rightIndex_ + 1;
      while (endRightIndex < rightInput_->size() &&
             compareRight(endRightIndex) == 0) {
        ++endRightIndex;
      }

      if (endRightIndex == rightInput_->size()) {
        // Matches continue in subsequent right input and rightInput_ is kept
        // in rightMatch_ past fetching the next right batch. Load all lazies,
        // exactly as for the left side above and in findEndOfMatch().
        loadColumns(rightInput_, *operatorCtx_->execCtx());
      }

      rightMatch_ = Match{
          {rightInput_},
          rightIndex_,
          endRightIndex,
          endRightIndex < rightInput_->size(),
          std::nullopt};

      if (!leftMatch_->complete || !rightMatch_->complete) {
        if (!leftMatch_->complete) {
          // Need to continue looking for the end of match.
          input_ = nullptr;
        }

        if (!rightMatch_->complete) {
          // Need to continue looking for the end of match.
          rightInput_ = nullptr;
        }
        return nullptr;
      }

      index_ = endIndex;
      rightIndex_ = endRightIndex;

      if (addToOutput()) {
        return std::move(output_);
      }

      compareResult = compare();
    }
  }

  VELOX_UNREACHABLE();
}

// Evaluates the join filter over 'output' and returns the surviving rows
// (wrapped in a dictionary when only some survive), or nullptr when none
// survive. For left joins, rows with no right-side match are always kept, and
// a left row whose matches all fail the filter is kept once with null
// right-side columns.
RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
  const auto numRows = output->size();

  BufferPtr indices = allocateIndices(numRows, pool());
  auto rawIndices = indices->asMutable<vector_size_t>();
  vector_size_t numPassed = 0;

  if (leftJoinTracker_) {
    const auto& filterRows = leftJoinTracker_->matchingRows(numRows);

    if (!filterRows.hasSelections()) {
      // No matches in the output, no need to evaluate the filter.
      return output;
    }

    evaluateFilter(filterRows);

    // If all matches for a given left-side row fail the filter, add a row to
    // the output with nulls for the right-side columns.
    auto onMiss = [&](auto row) {
      rawIndices[numPassed++] = row;

      for (const auto& projection : rightProjections_) {
        const auto& target = output->childAt(projection.outputChannel);
        target->setNull(row, true);
      }
    };

    for (vector_size_t i = 0; i < numRows; ++i) {
      if (filterRows.isValid(i)) {
        const bool passed = !decodedFilterResult_.isNullAt(i) &&
            decodedFilterResult_.valueAt<bool>(i);

        leftJoinTracker_->processFilterResult(i, passed, onMiss);

        if (passed) {
          rawIndices[numPassed++] = i;
        }
      } else {
        // This row doesn't have a match on the right side. Keep it
        // unconditionally.
        rawIndices[numPassed++] = i;
      }
    }

    // NOTE(review): leftMatch_ can also be set when the output filled up
    // exactly at a left-row boundary (cursor on the next left row); in that
    // case the pending left row is not flushed here and its output index is
    // carried into the next batch — TODO confirm against LeftJoinTracker.
    if (!leftMatch_) {
      leftJoinTracker_->noMoreFilterResults(onMiss);
    }

    // onMiss() reports a left row only after later rows of the batch may
    // already have been selected (e.g. unmatched rows kept unconditionally
    // above, or via noMoreFilterResults() after the loop). Each index is
    // selected at most once, so sorting restores the original, key-ordered
    // row order that a merge join must preserve.
    std::sort(rawIndices, rawIndices + numPassed);
  } else {
    filterRows_.resize(numRows);
    filterRows_.setAll();

    evaluateFilter(filterRows_);

    for (vector_size_t i = 0; i < numRows; ++i) {
      if (!decodedFilterResult_.isNullAt(i) &&
          decodedFilterResult_.valueAt<bool>(i)) {
        rawIndices[numPassed++] = i;
      }
    }
  }

  if (numPassed == 0) {
    // No rows passed.
    return nullptr;
  }

  if (numPassed == numRows) {
    // All rows passed.
    return output;
  }

  // Some, but not all rows passed.
  return wrap(numPassed, indices, output);
}

// Evaluates filter_ over 'rows' of filterInput_ and decodes the result into
// decodedFilterResult_.
void MergeJoin::evaluateFilter(const SelectivityVector& rows) {
  EvalCtx evalCtx(operatorCtx_->execCtx(), filter_.get(), filterInput_.get());
  filter_->eval(0, 1, true, rows, &evalCtx, &filterResult_);

  decodedFilterResult_.decode(*filterResult_[0], rows);
}

// Finished once the left side has signalled no more input and the current
// left batch has been fully consumed.
bool MergeJoin::isFinished() {
  return noMoreInput_ && input_ == nullptr;
}

} // namespace facebook::velox::exec