override def apply()

in extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/DynamicShufflePartitions.scala [35:96]


  /**
   * Rewrites the partition count of eligible shuffle exchanges in `plan` based on the
   * amount of input data found beneath it.
   *
   * The rule is a no-op (returns `plan` unchanged) when either the dynamic shuffle
   * partitions flag or adaptive execution is disabled, when no input size could be
   * collected from the plan, or when the rewritten plan fails
   * `ValidateRequirements.validate`.
   *
   * @param plan the physical plan to inspect and possibly rewrite
   * @return the plan with adjusted shuffle partition counts, or the original `plan`
   */
  override def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.getConf(DYNAMIC_SHUFFLE_PARTITIONS) || !conf.getConf(ADAPTIVE_EXECUTION_ENABLED)) {
      plan
    } else {
      // Walks the plan and gathers one size (in bytes) per recognised data source node.
      def collectScanSizes(node: SparkPlan): Seq[Long] = node match {
        case FileSourceScanExec(relation, _, _, _, _, _, _, _, _) =>
          Seq(relation.location.sizeInBytes)
        case t: HiveTableScanExec =>
          t.relation.prunedPartitions match {
            case Some(partitions) =>
              // Partitions without stats contribute nothing to the sum.
              Seq(partitions.flatMap(_.stats).map(_.sizeInBytes.toLong).sum)
            case None =>
              // A size equal to the configured default is dropped rather than used.
              Seq(t.relation.computeStats().sizeInBytes.toLong)
                .filter(_ != conf.defaultSizeInBytes)
          }
        case stage: ShuffleQueryStageExec if stage.isMaterialized && stage.mapStats.isDefined =>
          stage.mapStats.map(_.bytesByPartitionId.sum).toSeq
        case other =>
          other.children.flatMap(collectScanSizes)
      }

      val scanSizes = collectScanSizes(plan)
      if (scanSizes.isEmpty) {
        plan
      } else {
        val targetSize = conf.getConf(ADVISORY_PARTITION_SIZE_IN_BYTES)
        val maxDynamicShufflePartitions = conf.getConf(DYNAMIC_SHUFFLE_PARTITIONS_MAX_NUM)

        // Clamp while still a Long and narrow to Int only afterwards: narrowing first
        // would wrap around (possibly to a negative number) whenever the size-derived
        // estimate exceeds Int.MaxValue, and that wrapped value would slip past the cap.
        val estimatedPartitions = scanSizes.sum / targetSize + 1
        val targetShufflePartitions = Math.min(
          Math.max(estimatedPartitions, conf.numShufflePartitions.toLong),
          maxDynamicShufflePartitions.toLong).toInt

        val newPlan = plan transformUp {
          // Exchanges whose origin is REPARTITION_BY_NUM are left untouched.
          case exchange @ ShuffleExchangeExec(outputPartitioning, _, shuffleOrigin, _)
              if shuffleOrigin != REPARTITION_BY_NUM =>
            outputPartitioning match {
              case RoundRobinPartitioning(numPartitions)
                  if targetShufflePartitions != numPartitions =>
                exchange.copy(outputPartitioning =
                  RoundRobinPartitioning(targetShufflePartitions))
              case HashPartitioning(expressions, numPartitions)
                  if targetShufflePartitions != numPartitions =>
                exchange.copy(outputPartitioning =
                  HashPartitioning(expressions, targetShufflePartitions))
              case RangePartitioning(ordering, numPartitions)
                  if targetShufflePartitions != numPartitions =>
                exchange.copy(outputPartitioning =
                  RangePartitioning(ordering, targetShufflePartitions))
              // Other partitionings, or ones already at the target, stay as they are.
              case _ => exchange
            }
        }

        if (ValidateRequirements.validate(newPlan)) {
          newPlan
        } else {
          logInfo("DynamicShufflePartitions rule generated an invalid plan. " +
            "Falling back to the original plan.")
          plan
        }
      }
    }
  }