core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan()

in cpp/velox/substrait/SubstraitToVeloxPlan.cc [251:355]


/// Converts a Substrait JoinRel into a Velox join plan node.
///
/// Steps, in order:
///   1. Verify both the left and right input relations are present, then
///      convert each of them recursively.
///   2. Map the Substrait join type onto core::JoinType. Semi and anti joins
///      additionally look at flags carried by the advanced extension
///      ("isExistenceJoin=", "isNullAwareAntiJoin=").
///   3. Extract the join key field references from the join expression and
///      convert them against the join input row type.
///   4. Convert the optional post-join filter.
///   5. Build a MergeJoinNode when the "isSMJ=" flag is set, otherwise a
///      HashJoinNode.
///
/// @param sJoin The Substrait join relation to convert.
/// @return The MergeJoinNode or HashJoinNode created for the join.
/// Fails via VELOX_FAIL when either input relation is missing, via VELOX_NYI
/// for an unmapped join type, and via VELOX_CHECK_EQ when the numbers of left
/// and right keys differ.
core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::JoinRel& sJoin) {
  if (!sJoin.has_left()) {
    VELOX_FAIL("Left Rel is expected in JoinRel.");
  }
  if (!sJoin.has_right()) {
    VELOX_FAIL("Right Rel is expected in JoinRel.");
  }

  auto leftNode = toVeloxPlan(sJoin.left());
  auto rightNode = toVeloxPlan(sJoin.right());

  // True when the join carries an advanced extension for which
  // SubstraitParser::configSetInOptimization reports the given flag.
  const auto optimizationFlagSet = [&sJoin](const char* flag) {
    return sJoin.has_advanced_extension() &&
        SubstraitParser::configSetInOptimization(sJoin.advanced_extension(), flag);
  };

  // Translate the Substrait join type into its Velox counterpart.
  core::JoinType joinType;
  bool isNullAwareAntiJoin = false;
  switch (sJoin.type()) {
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_INNER:
      joinType = core::JoinType::kInner;
      break;
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_OUTER:
      joinType = core::JoinType::kFull;
      break;
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT:
      joinType = core::JoinType::kLeft;
      break;
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT:
      joinType = core::JoinType::kRight;
      break;
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI:
      // The existence-join flag selects the "project" flavour of the semi
      // join; without it the "filter" flavour is used.
      joinType = optimizationFlagSet("isExistenceJoin=") ? core::JoinType::kLeftSemiProject
                                                         : core::JoinType::kLeftSemiFilter;
      break;
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI:
      // Same selection as for the left semi join, mirrored to the right side.
      joinType = optimizationFlagSet("isExistenceJoin=") ? core::JoinType::kRightSemiProject
                                                         : core::JoinType::kRightSemiFilter;
      break;
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_ANTI:
      // Null-awareness is not part of the join type itself; it is forwarded
      // separately to the hash join node below.
      isNullAwareAntiJoin = optimizationFlagSet("isNullAwareAntiJoin=");
      joinType = core::JoinType::kAnti;
      break;
    default:
      VELOX_NYI("Unsupported Join type: {}", std::to_string(sJoin.type()));
  }

  // Pull the key field references of both sides out of the join expression.
  std::vector<const ::substrait::Expression::FieldReference*> leftRefs;
  std::vector<const ::substrait::Expression::FieldReference*> rightRefs;
  extractJoinKeys(sJoin.expression(), leftRefs, rightRefs);
  VELOX_CHECK_EQ(leftRefs.size(), rightRefs.size());
  const size_t keyCount = leftRefs.size();

  // Both sides' keys (and the post-join filter) are converted against the
  // row type returned by getJoinInputType for the two children.
  auto joinInputType = getJoinInputType(leftNode, rightNode);

  std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> leftKeys;
  std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> rightKeys;
  leftKeys.reserve(keyCount);
  rightKeys.reserve(keyCount);
  for (size_t keyIdx = 0; keyIdx < keyCount; ++keyIdx) {
    leftKeys.emplace_back(exprConverter_->toVeloxExpr(*leftRefs[keyIdx], joinInputType));
    rightKeys.emplace_back(exprConverter_->toVeloxExpr(*rightRefs[keyIdx], joinInputType));
  }

  // The filter stays null when the join has no post-join filter.
  core::TypedExprPtr postJoinFilter;
  if (sJoin.has_post_join_filter()) {
    postJoinFilter = exprConverter_->toVeloxExpr(sJoin.post_join_filter(), joinInputType);
  }

  if (optimizationFlagSet("isSMJ=")) {
    // Sort-merge join requested: build a MergeJoinNode. The null-aware flag
    // is not passed to this node type.
    return std::make_shared<core::MergeJoinNode>(
        nextPlanNodeId(),
        joinType,
        leftKeys,
        rightKeys,
        postJoinFilter,
        leftNode,
        rightNode,
        getJoinOutputType(leftNode, rightNode, joinType));
  }

  // Default: build a HashJoinNode.
  return std::make_shared<core::HashJoinNode>(
      nextPlanNodeId(),
      joinType,
      isNullAwareAntiJoin,
      leftKeys,
      rightKeys,
      postJoinFilter,
      leftNode,
      rightNode,
      getJoinOutputType(leftNode, rightNode, joinType));
}