in cpp/velox/substrait/SubstraitToVeloxPlan.cc [251:355]
// Converts a Substrait JoinRel into a Velox join plan node.
//
// Steps, in order:
//   1. Require both inputs and convert them (left first, then right).
//   2. Map the Substrait join type to a core::JoinType.
//   3. Extract the join key field references from the join expression and
//      convert them to Velox field accesses.
//   4. Convert the post-join filter when one is present.
//   5. Build a MergeJoinNode when the "isSMJ=" check on the advanced extension
//      succeeds, otherwise a HashJoinNode.
//
// Fails through VELOX_FAIL when an input is missing, through VELOX_NYI for a
// join type that has no case below, and through VELOX_CHECK_EQ when the two
// key lists differ in length.
core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::JoinRel& sJoin) {
  if (!sJoin.has_left()) {
    VELOX_FAIL("Left Rel is expected in JoinRel.");
  }
  if (!sJoin.has_right()) {
    VELOX_FAIL("Right Rel is expected in JoinRel.");
  }

  auto leftInput = toVeloxPlan(sJoin.left());
  auto rightInput = toVeloxPlan(sJoin.right());

  // True when the rel has an advanced extension and
  // SubstraitParser::configSetInOptimization accepts it for the given key.
  const auto hasOptimizationFlag = [&sJoin](const char* key) {
    return sJoin.has_advanced_extension() &&
        SubstraitParser::configSetInOptimization(sJoin.advanced_extension(), key);
  };

  // Map join type.
  using SubstraitJoinType = ::substrait::JoinRel_JoinType;
  core::JoinType veloxJoinType;
  bool isNullAwareAntiJoin = false;
  switch (sJoin.type()) {
    case SubstraitJoinType::JoinRel_JoinType_JOIN_TYPE_INNER:
      veloxJoinType = core::JoinType::kInner;
      break;
    case SubstraitJoinType::JoinRel_JoinType_JOIN_TYPE_OUTER:
      veloxJoinType = core::JoinType::kFull;
      break;
    case SubstraitJoinType::JoinRel_JoinType_JOIN_TYPE_LEFT:
      veloxJoinType = core::JoinType::kLeft;
      break;
    case SubstraitJoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT:
      veloxJoinType = core::JoinType::kRight;
      break;
    case SubstraitJoinType::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI:
      // The "isExistenceJoin=" check selects the project flavor; otherwise
      // the filter flavor is used.
      veloxJoinType = hasOptimizationFlag("isExistenceJoin=") ? core::JoinType::kLeftSemiProject
                                                              : core::JoinType::kLeftSemiFilter;
      break;
    case SubstraitJoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI:
      // Same selection as the left semi case, mirrored to the right side.
      veloxJoinType = hasOptimizationFlag("isExistenceJoin=") ? core::JoinType::kRightSemiProject
                                                              : core::JoinType::kRightSemiFilter;
      break;
    case SubstraitJoinType::JoinRel_JoinType_JOIN_TYPE_ANTI:
      // The join type is always kAnti; null awareness is carried separately.
      isNullAwareAntiJoin = hasOptimizationFlag("isNullAwareAntiJoin=");
      veloxJoinType = core::JoinType::kAnti;
      break;
    default:
      VELOX_NYI("Unsupported Join type: {}", std::to_string(sJoin.type()));
  }

  // Extract join keys from the join expression.
  std::vector<const ::substrait::Expression::FieldReference*> leftExprs;
  std::vector<const ::substrait::Expression::FieldReference*> rightExprs;
  extractJoinKeys(sJoin.expression(), leftExprs, rightExprs);
  VELOX_CHECK_EQ(leftExprs.size(), rightExprs.size());

  const size_t keyCount = leftExprs.size();
  std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> leftKeys;
  std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> rightKeys;
  leftKeys.reserve(keyCount);
  rightKeys.reserve(keyCount);

  // Keys on both sides, and the post-join filter, are converted against the
  // row type returned by getJoinInputType for the two inputs.
  auto joinInputType = getJoinInputType(leftInput, rightInput);
  for (size_t keyIdx = 0; keyIdx != keyCount; ++keyIdx) {
    leftKeys.emplace_back(exprConverter_->toVeloxExpr(*leftExprs[keyIdx], joinInputType));
    rightKeys.emplace_back(exprConverter_->toVeloxExpr(*rightExprs[keyIdx], joinInputType));
  }

  // Stays default-constructed when the rel carries no post-join filter.
  core::TypedExprPtr postJoinFilter;
  if (sJoin.has_post_join_filter()) {
    postJoinFilter = exprConverter_->toVeloxExpr(sJoin.post_join_filter(), joinInputType);
  }

  if (hasOptimizationFlag("isSMJ=")) {
    // Create MergeJoinNode node. isNullAwareAntiJoin is not among the
    // arguments passed here, so it only affects the hash join path below.
    return std::make_shared<core::MergeJoinNode>(
        nextPlanNodeId(),
        veloxJoinType,
        leftKeys,
        rightKeys,
        postJoinFilter,
        leftInput,
        rightInput,
        getJoinOutputType(leftInput, rightInput, veloxJoinType));
  }

  // Create HashJoinNode node.
  return std::make_shared<core::HashJoinNode>(
      nextPlanNodeId(),
      veloxJoinType,
      isNullAwareAntiJoin,
      leftKeys,
      rightKeys,
      postJoinFilter,
      leftInput,
      rightInput,
      getJoinOutputType(leftInput, rightInput, veloxJoinType));
}