in cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc [991:1064]
// Checks whether a Substrait JoinRel can be converted to a Velox plan.
//
// The checks run in this order and stop at the first failure:
//   1. The left and right inputs, when present, are validated recursively.
//   2. If the advanced extension marks the join as a sort-merge join
//      ("isSMJ=" in the optimization config), its join type must be in the
//      accepted set.
//   3. The join type must be in the accepted set.
//   4. The advanced extension must be present and must carry the input types.
//      Those types are parsed and flattened one level into per-column types.
//   5. If a join expression is present, join keys are extracted from it.
//   6. If a post-join filter is present, it is converted against a row type
//      built from the input column types and compiled into an ExprSet.
//
// Steps 5 and 6 discard their results. NOTE(review): they appear to act as
// "does this convert/compile" probes, with any failure surfacing as an
// exception to the caller. Confirm how callers handle exceptions here.
//
// Returns true when every check passes, false otherwise. Each false return is
// preceded by a LOG_VALIDATION_MSG describing the reason.
bool SubstraitToVeloxPlanValidator::validate(const ::substrait::JoinRel& joinRel) {
  // Recurse into the children first; a bad input invalidates the whole join.
  if (joinRel.has_left() && !validate(joinRel.left())) {
    LOG_VALIDATION_MSG("Validation fails for join left input.");
    return false;
  }
  if (joinRel.has_right() && !validate(joinRel.right())) {
    LOG_VALIDATION_MSG("Validation fails for join right input.");
    return false;
  }

  // Join types accepted by this validator. The sort-merge-join check and the
  // generic check currently accept the same set. They remain two separate
  // checks so the logged reason says which one rejected the plan.
  const auto isAcceptedJoinType = [](auto type) -> bool {
    switch (type) {
      case ::substrait::JoinRel_JoinType_JOIN_TYPE_INNER:
      case ::substrait::JoinRel_JoinType_JOIN_TYPE_OUTER:
      case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT:
      case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT:
      case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI:
      case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI:
      case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_ANTI:
        return true;
      default:
        return false;
    }
  };

  const auto joinType = joinRel.type();
  const bool isSortMergeJoin = joinRel.has_advanced_extension() &&
      SubstraitParser::configSetInOptimization(joinRel.advanced_extension(), "isSMJ=");
  // The sort-merge-join rejection is checked first, so an unsupported SMJ
  // reports the SMJ-specific reason rather than the generic one.
  if (isSortMergeJoin && !isAcceptedJoinType(joinType)) {
    LOG_VALIDATION_MSG("Sort merge join type is not supported: " + std::to_string(joinType));
    return false;
  }
  if (!isAcceptedJoinType(joinType)) {
    LOG_VALIDATION_MSG("Join type is not supported: " + std::to_string(joinType));
    return false;
  }

  // The join's input schema travels in the advanced extension.
  if (!joinRel.has_advanced_extension()) {
    LOG_VALIDATION_MSG("Input types are expected in JoinRel.");
    return false;
  }
  TypePtr parsedType;
  std::vector<TypePtr> columnTypes;
  // flattenSingleLevel only runs when parseVeloxType succeeded (short-circuit).
  const bool typesParsed =
      parseVeloxType(joinRel.advanced_extension(), parsedType) && flattenSingleLevel(parsedType, columnTypes);
  if (!typesParsed) {
    LOG_VALIDATION_MSG("Validation failed for input types in JoinRel.");
    return false;
  }

  // Give every flattened column a generated name under plan node 0.
  const int32_t inputPlanNodeId = 0;
  const auto columnCount = columnTypes.size();
  std::vector<std::string> columnNames;
  columnNames.reserve(columnCount);
  for (auto colIdx = 0; colIdx < columnCount; ++colIdx) {
    columnNames.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx));
  }
  auto joinInputRowType = std::make_shared<RowType>(std::move(columnNames), std::move(columnTypes));

  // Only the extraction itself matters here; the key lists are not inspected.
  if (joinRel.has_expression()) {
    std::vector<const ::substrait::Expression::FieldReference*> leftKeys;
    std::vector<const ::substrait::Expression::FieldReference*> rightKeys;
    planConverter_.extractJoinKeys(joinRel.expression(), leftKeys, rightKeys);
  }

  // Convert and compile the post-join filter; the compiled set is not kept.
  if (joinRel.has_post_join_filter()) {
    auto filterExpr = exprConverter_->toVeloxExpr(joinRel.post_join_filter(), joinInputRowType);
    exec::ExprSet compiledFilter({std::move(filterExpr)}, execCtx_.get());
  }
  return true;
}