private def checkKeyGroupCompatible()

in extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiEnsureRequirements.scala [223:384]
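Checks whether the two children of a join are key-group compatible for a storage-partitioned join. Returns the (possibly rewritten) children when the shuffle can be avoided, or None when a shuffle is still required.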


  private def checkKeyGroupCompatible(
      left: SparkPlan,
      right: SparkPlan,
      joinType: JoinType,
      requiredChildDistribution: Seq[Distribution]): Option[Seq[SparkPlan]] = {
    assert(requiredChildDistribution.length == 2)

    var newLeft = left
    var newRight = right

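    // Both required child distributions must be ClusteredDistribution, and each child's
    // output partitioning must produce a key-grouped shuffle spec; otherwise bail out
    // early with None.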
    val specs = Seq(left, right).zip(requiredChildDistribution).map { case (p, d) =>
      if (!d.isInstanceOf[ClusteredDistribution]) return None
      val cd = d.asInstanceOf[ClusteredDistribution]
      val specOpt = createKeyGroupedShuffleSpec(p.outputPartitioning, cd)
      if (specOpt.isEmpty) return None
      specOpt.get
    }

    val leftSpec = specs.head
    val rightSpec = specs(1)

    var isCompatible = false
    if (!conf.v2BucketingPushPartValuesEnabled) {
      isCompatible = leftSpec.isCompatibleWith(rightSpec)
    } else {
      logInfo("Pushing common partition values for storage-partitioned join")
      isCompatible = leftSpec.areKeysCompatible(rightSpec)

      // Partition expressions are compatible. Regardless of whether the partition values
      // from the two children match, we can calculate a superset of the partition values
      // and push it down to the respective data sources, so they can adjust their output
      // partitioning by filling missing partition keys with empty partitions. As a
      // result, we can still avoid the shuffle.
      //
      // For instance, if two sides of a join have partition expressions
      // `day(a)` and `day(b)` respectively
      // (the join query could be `SELECT ... FROM t1 JOIN t2 on t1.a = t2.b`), but
      // with different partition values:
      //   `day(a)`: [0, 1]
      //   `day(b)`: [1, 2, 3]
      // In this case we don't have to shuffle either side; instead, we can just push the
      // merged superset of partition values, `[0, 1, 2, 3]`, down to the two data
      // sources.
      if (isCompatible) {
        val leftPartValues = leftSpec.partitioning.partitionValues
        val rightPartValues = rightSpec.partitioning.partitionValues

        logInfo(
          s"""
             |Left side # of partitions: ${leftPartValues.size}
             |Right side # of partitions: ${rightPartValues.size}
             |""".stripMargin)

        // As the partition keys are compatible, we can pick either side's partition
        // expressions
        val partitionExprs = leftSpec.partitioning.expressions

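        // Merge (union) the distinct partition values of both sides; each value starts
        // out mapped to a single expected partition, which partial clustering may
        // overwrite below.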
        var mergedPartValues = InternalRowComparableWrapper
          .mergePartitions(leftSpec.partitioning, rightSpec.partitioning, partitionExprs)
          .map(v => (v, 1))

        logInfo(s"After merging, there are ${mergedPartValues.size} partitions")

        var replicateLeftSide = false
        var replicateRightSide = false
        var applyPartialClustering = false

        // This means we allow partitions that are not clustered on their values,
        // that is, multiple partitions with the same partition value. In the
        // following, we calculate how many partitions each distinct partition
        // value has, and push that information down to the scans, so they can
        // adjust their final input partitions accordingly.
        if (conf.v2BucketingPartiallyClusteredDistributionEnabled) {
          logInfo("Calculating partially clustered distribution for " +
            "storage-partitioned join")

          // Similar to `OptimizeSkewedJoin`, we need to check the join type and decide
          // whether partially clustered distribution can be applied. For instance, the
          // optimization cannot be applied to a left outer join when the left-hand
          // side is chosen, according to stats, as the side whose partitions are
          // replicated. Otherwise, the query result could be incorrect.
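          // `canReplicateLeftSide` / `canReplicateRightSide` mirror
          // `OptimizeSkewedJoin.canSplitRightSide` / `canSplitLeftSide`: a side may only
          // be replicated for join types where duplicating that side's rows cannot
          // change the result (e.g. the left side only for inner, cross and right
          // outer joins).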
          val canReplicateLeft = canReplicateLeftSide(joinType)
          val canReplicateRight = canReplicateRightSide(joinType)

          if (!canReplicateLeft && !canReplicateRight) {
            logInfo("Skipping partially clustered distribution as it cannot be applied for " +
              s"join type '$joinType'")
          } else {
            val leftLink = left.logicalLink
            val rightLink = right.logicalLink

            replicateLeftSide =
              if (leftLink.isDefined && rightLink.isDefined &&
                leftLink.get.stats.sizeInBytes > 1 &&
                rightLink.get.stats.sizeInBytes > 1) {
                logInfo(
                  s"""
                     |Using plan statistics to determine which side of join to fully
                     |cluster partition values:
                     |Left side size (in bytes): ${leftLink.get.stats.sizeInBytes}
                     |Right side size (in bytes): ${rightLink.get.stats.sizeInBytes}
                     |""".stripMargin)
                leftLink.get.stats.sizeInBytes < rightLink.get.stats.sizeInBytes
              } else {
                // As a simple heuristic, we pick the side with the smaller number of
                // partitions to apply the grouping & replication of partitions
                logInfo("Using number of partitions to determine which side of join " +
                  "to fully cluster partition values")
                leftPartValues.size < rightPartValues.size
              }

            replicateRightSide = !replicateLeftSide

            // Similar to skewed join, we need to check the join type to see whether
            // replication of partitions can be applied. For instance, replication should
            // not be allowed for the left-hand side of a left outer join, as replicated
            // left rows with no match in their partition would each emit a null-extended
            // row.
            if (replicateLeftSide && !canReplicateLeft) {
              logInfo("Left-hand side is picked but cannot be applied to join type " +
                s"'$joinType'. Skipping partially clustered distribution.")
              replicateLeftSide = false
            } else if (replicateRightSide && !canReplicateRight) {
              logInfo("Right-hand side is picked but cannot be applied to join type " +
                s"'$joinType'. Skipping partially clustered distribution.")
              replicateRightSide = false
            } else {
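              // For each distinct partition value, count how many input partitions the
              // non-replicated (partially clustered) side has; the wrapper makes the
              // grouping compare rows by value rather than by reference.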
              val partValues = if (replicateLeftSide) rightPartValues else leftPartValues
              val numExpectedPartitions = partValues
                .map(InternalRowComparableWrapper(_, partitionExprs))
                .groupBy(identity)
                .mapValues(_.size)

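              // Replace the default count of 1 with the expected number of partitions
              // for each merged partition value, so the scans keep that many un-grouped
              // partitions per value.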
              mergedPartValues = mergedPartValues.map { case (partVal, numParts) =>
                (
                  partVal,
                  numExpectedPartitions.getOrElse(
                    InternalRowComparableWrapper(partVal, partitionExprs),
                    numParts))
              }

              logInfo("After applying partially clustered distribution, there are " +
                s"${mergedPartValues.map(_._2).sum} partitions.")
              applyPartialClustering = true
            }
          }
        }

        // Now we need to push the merged partition values down to the scan in each child
        newLeft = populatePartitionValues(
          left,
          mergedPartValues,
          applyPartialClustering,
          replicateLeftSide)
        newRight = populatePartitionValues(
          right,
          mergedPartValues,
          applyPartialClustering,
          replicateRightSide)
      }
    }

    if (isCompatible) Some(Seq(newLeft, newRight)) else None
  }
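
The merge-then-count bookkeeping above is easier to follow in isolation. Below is a minimal, self-contained sketch of the same idea in plain Scala, with Int standing in for InternalRow partition values; the PartitionMergeSketch object and the mergeThenCount helper are hypothetical names for illustration, not part of the Spark or Kyuubi API.

object PartitionMergeSketch {

  // Each side reports one Int per input partition: that partition's evaluated
  // partition value, e.g. the result of `day(a)`.
  def mergeThenCount(
      leftPartValues: Seq[Int],
      rightPartValues: Seq[Int],
      replicateLeftSide: Boolean): Seq[(Int, Int)] = {
    // Superset (union) of the distinct partition values from both sides, mirroring
    // InternalRowComparableWrapper.mergePartitions; every value starts with one
    // expected partition.
    val merged = (leftPartValues ++ rightPartValues).distinct.sorted.map(v => (v, 1))

    // For partially clustered distribution, count the input partitions that the
    // non-replicated side has per distinct value, mirroring the groupBy step above.
    val partValues = if (replicateLeftSide) rightPartValues else leftPartValues
    val numExpectedPartitions =
      partValues.groupBy(identity).map { case (v, ps) => v -> ps.size }

    // Keep the default of 1 for values the non-replicated side doesn't have.
    merged.map { case (v, n) => (v, numExpectedPartitions.getOrElse(v, n)) }
  }

  def main(args: Array[String]): Unit = {
    // The example from the comment above: `day(a)` -> [0, 1], `day(b)` -> [1, 2, 3],
    // where value 2 on the right side spans two un-clustered input partitions.
    val left = Seq(0, 1)
    val right = Seq(1, 2, 2, 3)
    // Replicate the smaller (left) side; the right side stays partially clustered.
    println(mergeThenCount(left, right, replicateLeftSide = true))
    // List((0,1), (1,1), (2,2), (3,1))
  }
}

Each resulting (value, count) pair is what populatePartitionValues pushes down to the scans: a missing value becomes an empty partition, and a count greater than one tells the scan to keep that many partitions for the value instead of grouping them into one.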