in extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiEnsureRequirements.scala [223:384]
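  /**
   * Checks whether the key-grouped output partitionings of the two join children are
   * compatible for a storage-partitioned join. Returns the (possibly adjusted) children
   * when they are, or None when a shuffle is still required.
   */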
  private def checkKeyGroupCompatible(
      left: SparkPlan,
      right: SparkPlan,
      joinType: JoinType,
      requiredChildDistribution: Seq[Distribution]): Option[Seq[SparkPlan]] = {
    assert(requiredChildDistribution.length == 2)

    var newLeft = left
    var newRight = right
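
    // Derive a key-grouped shuffle spec for each child; bail out early if either side is
    // not clustered or not key-group partitioned.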
    val specs = Seq(left, right).zip(requiredChildDistribution).map { case (p, d) =>
      if (!d.isInstanceOf[ClusteredDistribution]) return None
      val cd = d.asInstanceOf[ClusteredDistribution]
      val specOpt = createKeyGroupedShuffleSpec(p.outputPartitioning, cd)
      if (specOpt.isEmpty) return None
      specOpt.get
    }

    val leftSpec = specs.head
    val rightSpec = specs(1)

    var isCompatible = false
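    // Without push-down of common partition values, the two specs are only compatible when
    // both sides report exactly matching partition values; otherwise it suffices that the
    // partition keys are compatible, since missing partition values are filled in below.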
    if (!conf.v2BucketingPushPartValuesEnabled) {
      isCompatible = leftSpec.isCompatibleWith(rightSpec)
    } else {
      logInfo("Pushing common partition values for storage-partitioned join")
      isCompatible = leftSpec.areKeysCompatible(rightSpec)

      // Partition expressions are compatible. Regardless of whether partition values
      // match from both sides of children, we can calculate a superset of partition values
      // and push it down to the respective data sources so they can adjust their output
      // partitioning by filling missing partition keys with empty partitions. As a result,
      // we can still avoid shuffle.
      //
      // For instance, the two sides of a join may have partition expressions
      // `day(a)` and `day(b)` respectively
      // (the join query could be `SELECT ... FROM t1 JOIN t2 ON t1.a = t2.b`), but
      // with different partition values:
      //   `day(a)`: [0, 1]
      //   `day(b)`: [1, 2, 3]
      // In this case we don't have to shuffle either side, but can instead push the merged
      // superset of partition values, `[0, 1, 2, 3]`, down to the two data sources.
      if (isCompatible) {
        val leftPartValues = leftSpec.partitioning.partitionValues
        val rightPartValues = rightSpec.partitioning.partitionValues

        logInfo(
          s"""
             |Left side # of partitions: ${leftPartValues.size}
             |Right side # of partitions: ${rightPartValues.size}
             |""".stripMargin)

        // As partition keys are compatible, we can pick either left or right as partition
        // expressions
        val partitionExprs = leftSpec.partitioning.expressions
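
        // Each merged partition value starts with a single expected partition; the count
        // may be raised below when partially clustered distribution replicates partitions.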
        var mergedPartValues = InternalRowComparableWrapper
          .mergePartitions(leftSpec.partitioning, rightSpec.partitioning, partitionExprs)
          .map(v => (v, 1))

        logInfo(s"After merging, there are ${mergedPartValues.size} partitions")

        var replicateLeftSide = false
        var replicateRightSide = false
        var applyPartialClustering = false

        // This means we allow partitions that are not clustered on their values,
        // that is, multiple partitions with the same partition value. In the
        // following, we calculate how many partitions each distinct partition
        // value has, and push the information down to scans, so they can adjust their
        // final input partitions respectively.
        if (conf.v2BucketingPartiallyClusteredDistributionEnabled) {
          logInfo("Calculating partially clustered distribution for " +
            "storage-partitioned join")

          // Similar to `OptimizeSkewedJoin`, we need to check the join type and decide
          // whether partially clustered distribution can be applied. For instance, the
          // optimization cannot be applied to a left outer join, where the left-hand
          // side is chosen as the side to replicate partitions according to stats.
          // Otherwise, the query result could be incorrect.
          val canReplicateLeft = canReplicateLeftSide(joinType)
          val canReplicateRight = canReplicateRightSide(joinType)

          if (!canReplicateLeft && !canReplicateRight) {
            logInfo("Skipping partially clustered distribution as it cannot be applied for " +
              s"join type '$joinType'")
          } else {
            val leftLink = left.logicalLink
            val rightLink = right.logicalLink

            replicateLeftSide =
              if (leftLink.isDefined && rightLink.isDefined &&
                  leftLink.get.stats.sizeInBytes > 1 &&
                  rightLink.get.stats.sizeInBytes > 1) {
                logInfo(
                  s"""
                     |Using plan statistics to determine which side of join to fully
                     |cluster partition values:
                     |Left side size (in bytes): ${leftLink.get.stats.sizeInBytes}
                     |Right side size (in bytes): ${rightLink.get.stats.sizeInBytes}
                     |""".stripMargin)
                leftLink.get.stats.sizeInBytes < rightLink.get.stats.sizeInBytes
              } else {
                // As a simple heuristic, we pick the side with fewer partitions to apply
                // the grouping & replication of partitions
                logInfo("Using number of partitions to determine which side of join " +
                  "to fully cluster partition values")
                leftPartValues.size < rightPartValues.size
              }

            replicateRightSide = !replicateLeftSide

            // Similar to skewed join, we need to check the join type to see whether replication
            // of partitions can be applied. For instance, replication should not be allowed for
            // the left-hand side of a right outer join.
            if (replicateLeftSide && !canReplicateLeft) {
              logInfo("Left-hand side is picked but cannot be applied to join type " +
                s"'$joinType'. Skipping partially clustered distribution.")
              replicateLeftSide = false
            } else if (replicateRightSide && !canReplicateRight) {
              logInfo("Right-hand side is picked but cannot be applied to join type " +
                s"'$joinType'. Skipping partially clustered distribution.")
              replicateRightSide = false
            } else {
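              // Take the partition values of the side that keeps its partitions grouped but
              // not merged, and count how many input partitions each distinct partition value
              // has; the replicated side duplicates its matching partitions that many times.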
              val partValues = if (replicateLeftSide) rightPartValues else leftPartValues
              val numExpectedPartitions = partValues
                .map(InternalRowComparableWrapper(_, partitionExprs))
                .groupBy(identity)
                .mapValues(_.size)

              mergedPartValues = mergedPartValues.map { case (partVal, numParts) =>
                (
                  partVal,
                  numExpectedPartitions.getOrElse(
                    InternalRowComparableWrapper(partVal, partitionExprs),
                    numParts))
              }

              logInfo("After applying partially clustered distribution, there are " +
                s"${mergedPartValues.map(_._2).sum} partitions.")
              applyPartialClustering = true
            }
          }
        }

        // Now we need to push the merged partition values down to the scan in each child
        newLeft = populatePartitionValues(
          left,
          mergedPartValues,
          applyPartialClustering,
          replicateLeftSide)
        newRight = populatePartitionValues(
          right,
          mergedPartValues,
          applyPartialClustering,
          replicateRightSide)
      }
    }

    if (isCompatible) Some(Seq(newLeft, newRight)) else None
  }