def apply()

in sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala [464:1113]


  def apply(
      innerPlan: LogicalPlan,
      outerPlan: LogicalPlan,
      handleCountBug: Boolean = false): (LogicalPlan, Seq[Expression]) = {
    val outputPlanInputAttrs = outerPlan.inputSet

    // The return type of the recursion.
    // The first parameter is a new logical plan with correlation eliminated.
    // The second parameter is a list of join conditions with the outer query.
    // The third parameter is a mapping between the outer references and equivalent
    // expressions from the inner query that is used to replace outer references.
    type ReturnType = (LogicalPlan, Seq[Expression], AttributeMap[Attribute])

    // Decorrelate the input plan.
    // parentOuterReferences: a set of parent outer references. As we recurse down we collect the
    // set of outer references that are part of the Domain, and use it to construct the DomainJoins
    // and join conditions.
    // aggregated: a boolean flag indicating whether the result of the plan will be aggregated
    // (or used as an input for a window function)
    // underSetOp: a boolean flag indicating whether a set operator (e.g. UNION) is a parent of the
    // inner plan.
    //
    // Steps:
    // 1. Recursively collects outer references from the inner query until it reaches a node
    //    that does not contain any correlated values.
    // 2. Inserts an optional [[DomainJoin]] node to indicate whether a domain (inner) join is
    //    needed between the outer query and the specific sub-tree of the inner query.
    // 3. Returns a list of join conditions with the outer query and a mapping between outer
    //    references with references inside the inner query. The parent nodes need to preserve
    //    the references inside the join conditions and substitute all outer references using
    //    the mapping.
    def decorrelate(
        plan: LogicalPlan,
        parentOuterReferences: AttributeSet,
        aggregated: Boolean = false,
        underSetOp: Boolean = false
    ): ReturnType = {
      val isCorrelated = hasOuterReferences(plan)
      if (!isCorrelated) {
        // We have reached a plan without correlation to the outer plan.
        if (parentOuterReferences.isEmpty) {
          // If there are no outer references from the parent nodes, it means all outer
          // attributes can be substituted by attributes from the inner plan. So no
          // domain join is needed.
          (plan, Nil, AttributeMap.empty[Attribute])
        } else {
          // Build the domain join with the parent outer references.
          val attributes = parentOuterReferences.toSeq
          val domains = attributes.map(_.newInstance())
          // A placeholder to be rewritten into domain join.
          val domainJoin = DomainJoin(domains, plan)
          val outerReferenceMap = Utils.toMap(attributes, domains)
          // Build join conditions between domain attributes and outer references.
          // EqualNullSafe is used to make sure null key can be joined together. Note
          // outer referenced attributes can be changed during the outer query optimization.
          // The equality conditions will also serve as an attribute mapping between new
          // outer references and domain attributes when rewriting the domain joins.
          // E.g. if the attribute a is changed to a1, the join condition a' <=> outer(a)
          // will become a' <=> a1, and we can construct the aliases based on the condition:
          // DomainJoin [a']        Join Inner
          // +- InnerQuery     =>   :- InnerQuery
          //                        +- Aggregate [a1] [a1 AS a']
          //                           +- OuterQuery
          val conditions = outerReferenceMap.map {
            case (o, a) =>
              val cond = EqualNullSafe(a, OuterReference(o))
              // SPARK-40615: Certain data types (e.g. MapType) do not support ordering, so
              // the EqualNullSafe join condition can become unresolved.
              if (!cond.resolved) {
                if (!RowOrdering.isOrderable(a.dataType)) {
                  throw QueryCompilationErrors.unsupportedCorrelatedReferenceDataTypeError(
                    o, a.dataType, plan.origin)
                } else {
                  throw SparkException.internalError(s"Unable to decorrelate subquery: " +
                    s"join condition '${cond.sql}' cannot be resolved.")
                }
              }
              cond
          }
          (domainJoin, conditions.toSeq, AttributeMap(outerReferenceMap))
        }
      } else {
        plan match {
          case Filter(condition, child) =>
            val conditions = splitConjunctivePredicates(condition)
            val (correlated, uncorrelated) = conditions.partition(containsOuter)
            // Find outer references that can be substituted by attributes from the inner
            // query using the equality predicates.
            // If we are under a set op, we never use the predicates directly to substitute outer
            // refs for now. Future improvement: use the predicates directly if they exist in all
            // children of the set op.
            val equivalences =
              if (underSetOp) AttributeMap.empty[Attribute]
              else collectEquivalentOuterReferences(correlated)
            // Correlated predicates can be removed from the Filter's condition and used as
            // join conditions with the outer query. However, if the result of the sub-tree
            // is aggregated, only certain correlated equality predicates can be used, because
            // the references in the join conditions need to be preserved in both the grouping
            // and aggregate expressions of an Aggregate, which may change the semantics of the
            // plan and lead to incorrect results. Here is an example:
            // Relations:
            //   t1(a, b): [(1, 1)]
            //   t2(c, d): [(1, 1), (2, 0)]
            //
            // Query:
            //   SELECT * FROM t1 WHERE a = (SELECT MAX(c) FROM t2 WHERE b >= d)
            //
            // Subquery plan transformation if correlated predicates are used as join conditions:
            //   Aggregate [max(c)]               Aggregate [d] [max(c), d]
            //   +- Filter (outer(b) >= d)   =>   +- Relation [c, d]
            //      +- Relation [c, d]
            //
            // Plan after rewrite:
            //   Project [a, b]                                   -- [(1, 1)]
            //   +- Join LeftOuter (b >= d AND a = max(c))
            //      :- Relation [a, b]
            //      +- Aggregate [d] [max(c), d]                  -- [(1, 1), (2, 0)]
            //         +- Relation [c, d]
            //
            // The result of the original query should be an empty set but the transformed
            // query will output an incorrect result of (1, 1). The correct transformation
            // with domain join is illustrated below:
            //   Aggregate [max(c)]               Aggregate [b'] [max(c), b']
            //   +- Filter (outer(b) >= d)   =>   +- Filter (b' >= d)
            //      +- Relation [c, d]               +- DomainJoin [b']
            //                                          +- Relation [c, d]
            // Plan after rewrite:
            //   Project [a, b]
            //   +- Join LeftOuter (b <=> b' AND a = max(c))  -- []
            //      :- Relation [a, b]
            //      +- Aggregate [b'] [max(c), b']            -- [(2, 1)]
            //         +- Join Inner (b' >= d)                -- [(1, 1, 1), (2, 0, 1)] (DomainJoin)
            //            :- Relation [c, d]
            //            +- Aggregate [b] [b AS b']          -- [(1)] (Domain)
            //               +- Relation [a, b]
            if (aggregated || underSetOp) {
              // Split the correlated predicates into predicates that can and cannot be directly
              // used as join conditions with the outer query depending on whether they can
              // be pulled up over an Aggregate without changing the semantics of the plan.
              // If we are under a set op, we never use the predicates directly for now. Future
              // improvement: use the predicates directly if they exist in all children of the set
              // op.
              val (equalityCond, predicates) =
                if (underSetOp) (Seq.empty[Expression], correlated)
                else correlated.partition(canPullUpOverAgg)
              val outerReferences = collectOuterReferences(predicates)
              val newOuterReferences =
                parentOuterReferences ++ outerReferences -- equivalences.keySet
              val (newChild, joinCond, outerReferenceMap) =
                decorrelate(child, newOuterReferences, aggregated, underSetOp)
              // Add the outer references mapping collected from the equality conditions.
              val newOuterReferenceMap = outerReferenceMap ++ equivalences
              // Replace all outer references in the non-equality predicates.
              val newCorrelated = replaceOuterReferences(predicates, newOuterReferenceMap)
              // The new filter condition is the original filter condition with correlated
              // equality predicates removed.
              val newFilterCond = newCorrelated ++ uncorrelated
              val newFilter = newFilterCond match {
                case Nil => newChild
                case conditions => Filter(conditions.reduce(And), newChild)
              }
              // Equality predicates are used as join conditions with the outer query.
              val newJoinCond = joinCond ++ equalityCond
              (newFilter, newJoinCond, newOuterReferenceMap)
            } else {
              // The result of this sub-tree is not aggregated, so all correlated predicates
              // can be directly used as outer query join conditions.
              val newOuterReferences = parentOuterReferences -- equivalences.keySet
              val (newChild, joinCond, outerReferenceMap) =
                decorrelate(child, newOuterReferences, aggregated, underSetOp)
              // Add the outer references mapping collected from the equality conditions.
              val newOuterReferenceMap = outerReferenceMap ++ equivalences
              val newFilter = uncorrelated match {
                case Nil => newChild
                case conditions => Filter(conditions.reduce(And), newChild)
              }
              val newJoinCond = joinCond ++ correlated
              (newFilter, newJoinCond, newOuterReferenceMap)
            }

          case Project(projectList, child) =>
            val outerReferences = collectOuterReferences(projectList)
            val newOuterReferences = parentOuterReferences ++ outerReferences
            val (newChild, joinCond, outerReferenceMap) =
              decorrelate(child, newOuterReferences, aggregated, underSetOp)
            // Replace all outer references in the original project list and keep the output
            // attributes unchanged.
            val newProjectList = replaceOuterInNamedExpressions(projectList, outerReferenceMap)
            // Preserve required domain attributes in the join condition by adding the missing
            // references to the new project list.
            val referencesToAdd = missingReferences(newProjectList, joinCond)
            val newProject = Project(newProjectList ++ referencesToAdd, newChild)
            (newProject, joinCond, outerReferenceMap)

          case Offset(offset, input) =>
            // OFFSET K is decorrelated by skipping top k rows per every domain value
            // via a row_number() window function, which is similar to limit decorrelation.
            // Limit and Offset situation are handled by limit branch as offset is the child
            // of limit in that case. This branch is for the case where there's no limit operator
            // above offset.
            //
            // An ORDER BY directly below the OFFSET is peeled off here: `child` is the plan
            // underneath the Sort (or `input` itself when there is no Sort) and `ordering` is
            // the sort order (empty when there is no Sort).
            val (child, ordering) = input match {
              case Sort(order, _, child, _) => (child, order)
              case _ => (input, Seq())
            }
            // Recurse into `child`, not `input`, in the same way the Limit branch does. The
            // peeled-off ordering is re-applied below through the row_number() window spec (after
            // its outer references are substituted), so the Sort node itself must not be carried
            // into the decorrelated plan a second time.
            val (newChild, joinCond, outerReferenceMap) =
              decorrelate(child, parentOuterReferences, aggregated = true, underSetOp)
            val collectedChildOuterReferences = collectOuterReferencesInPlanTree(child)
            // Add outer references to the PARTITION BY clause
            val partitionFields = collectedChildOuterReferences
              .filter(outerReferenceMap.contains(_))
              .map(outerReferenceMap(_)).toSeq
            if (partitionFields.isEmpty) {
              // Underlying subquery has no predicates connecting inner and outer query.
              // In this case, offset can be computed over the inner query directly.
              (Offset(offset, newChild), joinCond, outerReferenceMap)
            } else {
              // Rewrite the ORDER BY expressions so that they refer to inner-query attributes
              // instead of outer references.
              val orderByFields = replaceOuterReferences(ordering, outerReferenceMap)

              val rowNumber = WindowExpression(RowNumber(),
                WindowSpecDefinition(partitionFields, orderByFields,
                  SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)))
              val rowNumberAlias = Alias(rowNumber, "rn")()
              // Window function computes row_number() when partitioning by correlated references,
              // and projects all the other fields from the input.
              val window = Window(Seq(rowNumberAlias),
                partitionFields, orderByFields, newChild)
              // Keep only the rows after the first `offset` rows of every partition.
              val filter = Filter(GreaterThan(rowNumberAlias.toAttribute, offset), window)
              // Drop the helper `rn` column again so the output matches newChild's output.
              val project = Project(newChild.output, filter)
              (project, joinCond, outerReferenceMap)
            }

          case Limit(limit, input) =>
            // LIMIT K (with potential ORDER BY or OFFSET) is decorrelated by computing
            // K rows per every domain value via a row_number() window function.
            // For example, for a subquery
            // (SELECT T2.a FROM T2 WHERE T2.b = OuterReference(x) ORDER BY T2.c LIMIT 3 OFFSET 2)
            // -- we need to get top 3 values of T2.a (ordering by T2.c) for every value of x with
            // an offset 2.
            // Following our general decorrelation procedure, 'x' is then replaced by T2.b, so the
            // subquery is decorrelated as:
            // SELECT * FROM (
            //   SELECT T2.a, row_number() OVER (PARTITION BY T2.b ORDER BY T2.c) AS rn FROM T2)
            // WHERE rn > 2 AND rn <= 2+3
            val (child, ordering, offsetExpr) = input match {
              case Sort(order, _, child, _) => (child, order, Literal(0))
              case Offset(offsetExpr, offsetChild@(Sort(order, _, child, _))) =>
                (child, order, offsetExpr)
              case Offset(offsetExpr, child) =>
                (child, Seq(), offsetExpr)
              case _ => (input, Seq(), Literal(0))
            }
            val (newChild, joinCond, outerReferenceMap) =
              decorrelate(child, parentOuterReferences, aggregated = true, underSetOp)
            val collectedChildOuterReferences = collectOuterReferencesInPlanTree(child)
            // Add outer references to the PARTITION BY clause
            val partitionFields = collectedChildOuterReferences
              .filter(outerReferenceMap.contains(_))
              .map(outerReferenceMap(_)).toSeq
            if (partitionFields.isEmpty) {
              // Underlying subquery has no predicates connecting inner and outer query.
              // In this case, limit can be computed over the inner query directly.
              offsetExpr match {
                case IntegerLiteral(0) => (Limit(limit, newChild), joinCond, outerReferenceMap)
                case _ => (Limit(limit, Offset(offsetExpr, newChild)), joinCond, outerReferenceMap)
              }
            } else {
              val orderByFields = replaceOuterReferences(ordering, outerReferenceMap)

              val rowNumber = WindowExpression(RowNumber(),
                WindowSpecDefinition(partitionFields, orderByFields,
                  SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)))
              val rowNumberAlias = Alias(rowNumber, "rn")()
              // Window function computes row_number() when partitioning by correlated references,
              // and projects all the other fields from the input.
              val window = Window(Seq(rowNumberAlias),
                partitionFields, orderByFields, newChild)
              val filter = offsetExpr match {
                case IntegerLiteral(0) =>
                  // If there is no offset, we can directly use the row number to filter the rows.
                  Filter(LessThanOrEqual(rowNumberAlias.toAttribute, limit), window)
                case _ =>
                  Filter(
                    And(
                      GreaterThan(rowNumberAlias.toAttribute, offsetExpr),
                      LessThanOrEqual(rowNumberAlias.toAttribute, Add(offsetExpr, limit))
                    ),
                    window
                  )
              }
              val project = Project(newChild.output, filter)
              (project, joinCond, outerReferenceMap)
            }

          case w @ Window(projectList, partitionSpec, orderSpec, child, hint) =>
            val outerReferences = collectOuterReferences(w.expressions)
            assert(outerReferences.isEmpty, s"Correlated column is not allowed in window " +
              s"function: $w")
            val newOuterReferences = parentOuterReferences ++ outerReferences
            val (newChild, joinCond, outerReferenceMap) =
              decorrelate(child, newOuterReferences, aggregated = true, underSetOp)
            // For now these are no-op, as we don't allow correlated references in the window
            // function itself.
            val newProjectList = replaceOuterReferences(projectList, outerReferenceMap)
            val newPartitionSpec = replaceOuterReferences(partitionSpec, outerReferenceMap)
            val newOrderSpec = replaceOuterReferences(orderSpec, outerReferenceMap)
            val referencesToAdd = missingReferences(newProjectList, joinCond)

            val newWindow = Window(newProjectList ++ referencesToAdd,
              partitionSpec = newPartitionSpec ++ referencesToAdd,
              orderSpec = newOrderSpec, newChild, hint)
            (newWindow, joinCond, outerReferenceMap)

          case a @ Aggregate(groupingExpressions, aggregateExpressions, child, _) =>
            val outerReferences = collectOuterReferences(a.expressions)
            val newOuterReferences = parentOuterReferences ++ outerReferences
            val (newChild, joinCond, outerReferenceMap) =
              decorrelate(child, newOuterReferences, aggregated = true, underSetOp)
            // Replace all outer references in grouping and aggregate expressions, and keep
            // the output attributes unchanged.
            val newGroupingExpr = replaceOuterReferences(groupingExpressions, outerReferenceMap)
            val newAggExpr = replaceOuterInNamedExpressions(aggregateExpressions, outerReferenceMap)
            // Add all required domain attributes to both grouping and aggregate expressions.
            val referencesToAdd = missingReferences(newAggExpr, joinCond)
            val newAggregate = a.copy(
              groupingExpressions = newGroupingExpr ++ referencesToAdd,
              aggregateExpressions = newAggExpr ++ referencesToAdd,
              child = newChild)

            // Preserving domain attributes over an Aggregate with an empty grouping expression
            // is subject to the "COUNT bug" that can lead to wrong answer:
            //
            // Suppose the original query is:
            //   SELECT a, (SELECT COUNT(*) cnt FROM t2 WHERE t1.a = t2.c) FROM t1
            //
            // Decorrelated plan:
            //   Project [a, scalar-subquery [a = c]]
            //   :  +- Aggregate [c] [count(*) AS cnt, c]
            //   :     +- Relation [c, d]
            //   +- Relation [a, b]
            //
            // After rewrite:
            //   Project [a, cnt]
            //   +- Join LeftOuter (a = c)
            //      :- Relation [a, b]
            //      +- Aggregate [c] [count(*) AS cnt, c]
            //         +- Relation [c, d]
            //
            //     T1            T2          T2' (GROUP BY c)
            // +---+---+     +---+---+     +---+-----+
            // | a | b |     | c | d |     | c | cnt |
            // +---+---+     +---+---+     +---+-----+
            // | 0 | 1 |     | 0 | 2 |     | 0 | 2   |
            // | 1 | 2 |     | 0 | 3 |     +---+-----+
            // +---+---+     +---+---+
            //
            // T1 nested loop join T2     T1 left outer join T2'
            // on (a = c):                on (a = c):
            // +---+-----+                +---+------+
            // | a | cnt |                | a | cnt  |
            // +---+-----+                +---+------+
            // | 0 | 2   |                | 0 | 2    |
            // | 1 | 0   | <--- correct   | 1 | null | <--- wrong result
            // +---+-----+                +---+------+
            //
            // If an aggregate is subject to the COUNT bug:
            // 1) add a column `true AS alwaysTrue` to the result of the aggregate
            // 2) insert a left outer domain join between the outer query and this aggregate
            // 3) rewrite the original aggregate's output column using the default value of the
            //    aggregate function and the alwaysTrue column.
            //
            // For example, T1 left outer join T2' with `alwaysTrue` marker:
            // +---+------+------------+--------------------------------+
            // | c | cnt  | alwaysTrue | if(isnull(alwaysTrue), 0, cnt) |
            // +---+------+------------+--------------------------------+
            // | 0 | 2    | true       | 2                              |
            // | 0 | null | null       | 0                              |  <--- correct result
            // +---+------+------------+--------------------------------+
            if (groupingExpressions.isEmpty && handleCountBug) {
              // Evaluate the aggregate expressions with zero tuples.
              val resultMap = RewriteCorrelatedScalarSubquery.evalAggregateOnZeroTups(newAggregate)
              val alwaysTrue = Alias(Literal.TrueLiteral, "alwaysTrue")()
              val alwaysTrueRef = alwaysTrue.toAttribute.withNullability(true)
              val expressions = ArrayBuffer.empty[NamedExpression]
              // Create new aliases for aggregate expressions that have non-null default
              // values and reconstruct the output with the `alwaysTrue` marker.
              val projectList = newAggregate.aggregateExpressions.map { a =>
                resultMap.get(a.exprId) match {
                  // Aggregate expression is not subject to the count bug.
                  case Some(Literal(null, _)) | None =>
                    expressions += a
                    // The attribute is nullable since it is from the right-hand side of a
                    // left outer join.
                    a.toAttribute.withNullability(true)
                  case Some(default) =>
                    assert(a.isInstanceOf[Alias], s"Cannot have non-aliased expression $a in " +
                      s"aggregate that evaluates to non-null value with zero tuples.")
                    val newAttr = a.newInstance()
                    val ref = newAttr.toAttribute.withNullability(true)
                    expressions += newAttr
                    Alias(If(IsNull(alwaysTrueRef), default, ref), a.name)(a.exprId)
                }
              }
              // Insert a placeholder left outer domain join between the outer query and
              // the aggregate node and use the current collected join conditions as the
              // left outer join condition.
              //
              // Original subquery:
              //   Aggregate [count(1) AS cnt]
              //   +- Filter (a = outer(c))
              //      +- Relation [a, b]
              //
              // After decorrelation and before COUNT bug handling:
              //   Aggregate [a] [count(1) AS cnt, a]
              //   +- Relation [a, b]
              //
              // joinCond with the outer query: (a = outer(c))
              //
              // Handle the COUNT bug:
              //   Project [if(isnull(alwaysTrue), 0, cnt') AS cnt, c']
              //   +- DomainJoin [c'] LeftOuter (a = c')
              //      +- Aggregate [a] [count(1) AS cnt', a, true AS alwaysTrue]
              //         +- Relation [a, b]
              //
              // New joinCond with the outer query: (c' <=> outer(c)), and the DomainJoin
              // will be written as:
              //   Project [if(isnull(alwaysTrue), 0, cnt') AS cnt, c']
              //   +- Join LeftOuter (a = c')
              //      :- Aggregate [c] [c AS c']
              //      :  +- OuterQuery [c, d]
              //      +- Aggregate [a] [count(1) AS cnt', a, true AS alwaysTrue]
              //         +- Relation [a, b]
              //
              val agg = newAggregate.copy(aggregateExpressions = expressions.toSeq :+ alwaysTrue)
              // Find all outer references that are used in the join conditions.
              val outerAttrs = collectOuterReferences(joinCond).toSeq
              // Create new instance of the outer attributes as if they are generated inside
              // the subquery by a left outer join with the outer query. Use new instance here
              // to avoid conflicting join attributes with the inner query.
              val domainAttrs = outerAttrs.map(_.newInstance())
              val mapping = AttributeMap(outerAttrs.zip(domainAttrs))
              // Use the current join conditions returned from the recursive call as the join
              // conditions for the left outer join. All outer references in the join
              // conditions are replaced by the newly created domain attributes.
              val condition = replaceOuterReferences(joinCond, mapping).reduceOption(And)
              val domainJoin = DomainJoin(domainAttrs, agg, LeftOuter, condition)
              // Original domain attributes preserved through Aggregate are no longer needed.
              val newProjectList = projectList.filter(!referencesToAdd.contains(_))
              val project = Project(newProjectList ++ domainAttrs, domainJoin)
              val newJoinCond = outerAttrs.zip(domainAttrs).map { case (outer, inner) =>
                EqualNullSafe(inner, OuterReference(outer))
              }
              (project, newJoinCond, mapping)
            } else {
              (newAggregate, joinCond, outerReferenceMap)
            }

          case d: Distinct =>
            val (newChild, joinCond, outerReferenceMap) =
              decorrelate(d.child, parentOuterReferences, aggregated = true, underSetOp)
            (d.copy(child = newChild), joinCond, outerReferenceMap)

          case j @ Join(left, right, joinType, condition, _) =>
            // Given 'condition', computes the tuple of
            // (correlated, uncorrelated, equalityCond, predicates, equivalences).
            // 'correlated' and 'uncorrelated' are the conjuncts with (resp. without)
            // outer (correlated) references. Furthermore, correlated conjuncts are split
            // into 'equalityCond' (those that are equalities) and all rest ('predicates').
            // 'equivalences' track equivalent attributes given 'equalityCond'.
            // The split is only performed if 'shouldDecorrelatePredicates' is true.
            // The input parameter 'isInnerJoin' is set to true for INNER joins and helps
            // determine whether some predicates can be lifted up from the join (this is only
            // valid for inner joins).
            // Example: For a 'condition' A = outer(X) AND B > outer(Y) AND C = D, the output
            // would be:
            // correlated = (A = outer(X), B > outer(Y))
            // uncorrelated = (C = D)
            // equalityCond = (A = outer(X))
            // predicates = (B > outer(Y))
            // equivalences: (X -> A), i.e. keyed by the outer attribute
            def splitCorrelatedPredicate(
                condition: Option[Expression],
                isInnerJoin: Boolean,
                shouldDecorrelatePredicates: Boolean):
            (Seq[Expression], Seq[Expression], Seq[Expression],
              Seq[Expression], AttributeMap[Attribute]) = {
              // Result layout: (correlated, uncorrelated, equalityCond, predicates, equivalences).
              val noExprs = Seq.empty[Expression]
              val noEquivalences = AttributeMap.empty[Attribute]
              if (!shouldDecorrelatePredicates) {
                // No split is performed: the condition (if present) is returned unsplit in the
                // 'uncorrelated' slot and every other component is empty.
                (noExprs, condition.toSeq, noExprs, noExprs, noEquivalences)
              } else {
                // Similar to Filters above, we split the join condition (if present) into correlated
                // and uncorrelated predicates, and separately handle joins under set and aggregation
                // operations.
                val conjuncts = condition.toSeq.flatMap(c => splitConjunctivePredicates(c))
                val (correlated, uncorrelated) = conjuncts.partition(containsOuter)
                if (underSetOp || !isInnerJoin) {
                  // Under a set op the predicates are never used directly, and for non-inner joins
                  // the join predicate is fully preserved. In both cases every correlated conjunct
                  // stays in 'predicates' and no equality conditions or equivalences are produced.
                  (correlated, uncorrelated, noExprs, correlated, noEquivalences)
                } else {
                  val equivalences = collectEquivalentOuterReferences(correlated)
                  val (equalityCond, predicates) = correlated.partition(canPullUpOverAgg)
                  (correlated, uncorrelated, equalityCond, predicates, equivalences)
                }
              }
            }

            val shouldDecorrelatePredicates =
              SQLConf.get.getConf(SQLConf.DECORRELATE_JOIN_PREDICATE_ENABLED)
            if (!shouldDecorrelatePredicates) {
              val outerReferences = collectOuterReferences(j.expressions)
              // Join condition containing outer references is not supported.
              assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j")
            }
            val (correlated, uncorrelated, equalityCond, predicates, equivalences) =
              splitCorrelatedPredicate(condition, joinType == Inner, shouldDecorrelatePredicates)
            val outerReferences = collectOuterReferences(j.expressions) ++
              collectOuterReferences(predicates)
            val newOuterReferences =
              parentOuterReferences ++ outerReferences -- equivalences.keySet
            var shouldPushToLeft = joinType match {
              case LeftOuter | LeftSemiOrAnti(_) | FullOuter => true
              case _ => hasOuterReferences(left)
            }
            val shouldPushToRight = joinType match {
              case RightOuter | FullOuter => true
              case _ => hasOuterReferences(right)
            }
            if (shouldDecorrelatePredicates && !shouldPushToLeft && !shouldPushToRight
              && !predicates.isEmpty) {
              // Neither left nor right children of the join have correlations, but the join
              // predicate does, and the correlations can not be replaced via equivalences.
              // Introduce a domain join on the left side of the join
              // (chosen arbitrarily) to provide values for the correlated attribute reference.
              shouldPushToLeft = true;
            }
            val (newLeft, leftJoinCond, leftOuterReferenceMap) = if (shouldPushToLeft) {
              decorrelate(left, newOuterReferences, aggregated, underSetOp)
            } else {
              (left, Nil, AttributeMap.empty[Attribute])
            }
            val (newRight, rightJoinCond, rightOuterReferenceMap) = if (shouldPushToRight) {
              decorrelate(right, newOuterReferences, aggregated, underSetOp)
            } else {
              (right, Nil, AttributeMap.empty[Attribute])
            }
            val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap ++
              equivalences
            val newCorrelated =
              if (shouldDecorrelatePredicates) {
                replaceOuterReferences(correlated, newOuterReferenceMap)
              } else Seq.empty[Expression]
            val newJoinCond = leftJoinCond ++ rightJoinCond ++ equalityCond
            // If we push the dependent join to both sides, we can augment the join condition
            // such that both sides are matched on the domain attributes. For example,
            // - Left Map: {outer(c1) = c1}
            // - Right Map: {outer(c1) = 10 - c1}
            // Then the join condition can be augmented with (c1 <=> 10 - c1).
            val augmentedConditions = leftOuterReferenceMap.flatMap {
              case (outer, inner) => rightOuterReferenceMap.get(outer).map(EqualNullSafe(inner, _))
            }
            val newCondition = (newCorrelated ++ uncorrelated
              ++ augmentedConditions).reduceOption(And)
            val newJoin = j.copy(left = newLeft, right = newRight, condition = newCondition)
            (newJoin, newJoinCond, newOuterReferenceMap)

          case s @ (_ : Union | _: SetOperation) =>
            // Decorrelating a set operation means pushing the domain join down into every one
            // of its children. Design notes:
            // https://docs.google.com/document/d/11b9ClCF2jYGU7vU2suOT7LRswYkg6tZ8_6xJbvxfh2I/edit
            //
            // The Domain has to contain the outer references of *all* children; otherwise the
            // set op would combine inner rows that correspond to different outer values.
            //
            // For instance, the inner subquery
            //   select c from t1 where t1.a = t_outer.a
            //   UNION ALL
            //   select c from t2 where t2.b = t_outer.b
            // has both a and b in its Domain and is rewritten to
            //   select c, t_outer.a, t_outer.b from t1 join t_outer where t1.a = t_outer.a
            //   UNION ALL
            //   select c, t_outer.a, t_outer.b from t2 join t_outer where t2.b = t_outer.b
            val setOpOuterRefs = collectOuterReferencesInPlanTree(s)
            val domainOuterRefs = AttributeSet(parentOuterReferences ++ setOpOuterRefs)

            // Outer references whose inner counterparts are appended to each child's output:
            // the full domain when the config flag is on, otherwise only the references
            // collected from the set op's own subtree.
            val projectFullDomain =
              SQLConf.get.getConf(SQLConf.DECORRELATE_UNION_OR_SET_OP_UNDER_LIMIT_ENABLED)
            val projectedOuterRefs = if (projectFullDomain) domainOuterRefs else setOpOuterRefs

            val decorrelatedChildren = s.children.map { child =>
              val (childPlan, childJoinCond, childOuterRefMap) =
                decorrelate(child, domainOuterRefs, aggregated, underSetOp = true)
              // An explicit Project appends the domain attributes after the child's original
              // output, so they occupy the same positions in every child of the set op.
              // Without it they could land at the beginning or the end of the output columns
              // depending on the child plan. The inner expressions for the domain are the
              // values of the child's outer reference map.
              val domainColumns = projectedOuterRefs.map(outerRef => childOuterRefMap(outerRef))
              val projectedChild = Project(child.output ++ domainColumns, childPlan)
              (projectedChild, childJoinCond, childOuterRefMap)
            }

            // The join conditions and the outer reference map are taken from the first child,
            // because attribute names are from the first child.
            val (_, firstJoinCond, firstOuterRefMap) = decorrelatedChildren.head
            val rewrittenSetOp = s.withNewChildren(decorrelatedChildren.map(_._1))
            (rewrittenSetOp, firstJoinCond, AttributeMap(firstOuterRefMap))

          case g: Generate if g.requiredChildOutput.isEmpty =>
            // Generate with non-empty required child output cannot host
            // outer reference. It is blocked by CheckAnalysis.
            //
            // Outer references used by the Generate's expressions are added to the set handed
            // down to the child, so that the child's outer reference map can be used to
            // rewrite the generator below.
            val outerReferences = collectOuterReferences(g.expressions)
            val newOuterReferences = parentOuterReferences ++ outerReferences
            // Forward `underSetOp` to the child, as the other recursive calls in this match do.
            // Omitting it would reset the flag to the parameter's default for the whole subtree
            // under this Generate, even when the Generate itself sits below a set operation.
            val (newChild, joinCond, outerReferenceMap) =
              decorrelate(g.child, newOuterReferences, aggregated, underSetOp)
            // Replace all outer references in the original generator expression.
            val newGenerator = replaceOuterReference(g.generator, outerReferenceMap)
            val newGenerate = g.copy(generator = newGenerator, child = newChild)
            (newGenerate, joinCond, outerReferenceMap)

          case u: UnaryNode =>
            // Generic unary operators must not carry outer references in their own
            // expressions; only the child is decorrelated, with the parent outer references
            // and both flags passed through unchanged.
            assert(
              collectOuterReferences(u.expressions).isEmpty,
              s"Correlated column is not allowed in $u")
            val (decorrelatedChild, childJoinCond, childOuterRefMap) =
              decorrelate(u.child, parentOuterReferences, aggregated, underSetOp)
            (u.withNewChildren(Seq(decorrelatedChild)), childJoinCond, childOuterRefMap)

          case unsupportedPlan =>
            // No case above matched this operator, so decorrelating through it is rejected.
            throw QueryExecutionErrors.decorrelateInnerQueryThroughPlanUnsupportedError(
              unsupportedPlan)
        }
      }
    }
    // Run BooleanSimplification on the inner plan, then start the recursion at its root
    // with an empty set of parent outer references.
    val simplifiedInnerPlan = BooleanSimplification(innerPlan)
    val (decorrelatedPlan, outerJoinConds, _) =
      decorrelate(simplifiedInnerPlan, AttributeSet.empty)
    // Deduplicate against the outer plan's input attributes, then strip the outer reference
    // wrappers from the returned join conditions.
    val (dedupedPlan, dedupedConds) =
      deduplicate(decorrelatedPlan, outerJoinConds, outputPlanInputAttrs)
    (dedupedPlan, stripOuterReferences(dedupedConds))
  }