def genShuffleDependency()

in backends-velox/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala [84:216]


  def genShuffleDependency(
      rdd: RDD[ColumnarBatch],
      outputAttributes: Seq[Attribute],
      newPartitioning: Partitioning,
      serializer: Serializer,
      writeMetrics: Map[String, SQLMetric],
      metrics: Map[String, SQLMetric],
      isSort: Boolean): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
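    // Record the output partition count as a driver-side SQL metric for this execution.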
    metrics("numPartitions").set(newPartitioning.numPartitions)
    val executionId = rdd.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
    SQLMetrics.postDriverMetricUpdates(
      rdd.sparkContext,
      executionId,
      metrics("numPartitions") :: Nil)
    // scalastyle:on argcount
    // only used for fallback range partitioning
    val rangePartitioner: Option[Partitioner] = newPartitioning match {
      case RangePartitioning(sortingExpressions, numPartitions) =>
        // Extract only the fields used for sorting to avoid collecting large fields that do not
        // affect the sorting result when deciding partition bounds in RangePartitioner.
        val rddForSampling = rdd.mapPartitionsInternal {
          iter =>
            // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
            // partition bounds. To get accurate samples, we need to copy the mutable keys.
            iter.flatMap(
              batch => {
                val rows = convertColumnarToRow(batch)
                val projection =
                  UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
                val mutablePair = new MutablePair[InternalRow, Null]()
                rows.map(row => mutablePair.update(projection(row).copy(), null))
              })
        }
        // Construct ordering on extracted sort key.
        val orderingAttributes = sortingExpressions.zipWithIndex.map {
          case (ord, i) =>
            ord.copy(child = BoundReference(i, ord.dataType, ord.nullable))
        }
        implicit val ordering = new LazilyGeneratedOrdering(orderingAttributes)
        val part = new RangePartitioner(
          numPartitions,
          rddForSampling,
          ascending = true,
          samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
        Some(part)
      case _ => None
    }

    // only used for fallback range partitioning
    def computeAndAddPartitionId(
        cbIter: Iterator[ColumnarBatch],
        partitionKeyExtractor: InternalRow => Any): Iterator[(Int, ColumnarBatch)] = {
      Iterators
        .wrap(
          cbIter
            .filter(cb => cb.numRows != 0 && cb.numCols != 0)
            .map {
              cb =>
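                // Allocate a single int column ("pid") to hold each row's computed partition id.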
                val pidVec = ArrowWritableColumnVector
                  .allocateColumns(cb.numRows, new StructType().add("pid", IntegerType))
                  .head
                convertColumnarToRow(cb).zipWithIndex.foreach {
                  case (row, i) =>
                    val pid = rangePartitioner.get.getPartition(partitionKeyExtractor(row))
                    pidVec.putInt(i, pid)
                }
                val pidBatch = VeloxColumnarBatches.toVeloxBatch(
                  ColumnarBatches.offload(
                    ArrowBufferAllocators.contextInstance(),
                    new ColumnarBatch(Array[ColumnVector](pidVec), cb.numRows)))
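                // Combine the pid column with the data batch into a single Velox batch
                // consumed by the native shuffle writer.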
                val newBatch = VeloxColumnarBatches.compose(pidBatch, cb)
                // The composed batch already holds pidBatch's shared ref, so closing it is safe.
                ColumnarBatches.forceClose(pidBatch)
                (0, newBatch)
            })
        .recyclePayload(p => ColumnarBatches.forceClose(p._2)) // FIXME why force close?
        .create()
    }

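    // Map the Spark partitioning to the short name understood by the native shuffle writer.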
    val nativePartitioning: NativePartitioning = newPartitioning match {
      case SinglePartition =>
        new NativePartitioning(GlutenShuffleUtils.SinglePartitioningShortName, 1)
      case RoundRobinPartitioning(n) =>
        new NativePartitioning(GlutenShuffleUtils.RoundRobinPartitioningShortName, n)
      case HashPartitioning(exprs, n) =>
        new NativePartitioning(GlutenShuffleUtils.HashPartitioningShortName, n)
      // Range partitioning falls back to row-based partition id computation.
      case RangePartitioning(orders, n) =>
        new NativePartitioning(GlutenShuffleUtils.RangePartitioningShortName, n)
    }

    val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
      newPartitioning.numPartitions > 1

    // Round-robin repartitioning is order sensitive unless rows are sorted before repartition:
    // which partition a row lands in depends on the input row order.
    val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition

    // The RDD passed to ShuffleDependency must be in the form of key-value pairs, but
    // ColumnarShuffleWriter computes partition ids from the ColumnarBatch on the native side
    // rather than reading the "key" part. Thus columnar shuffle never uses the "key" part;
    // a dummy key of 0 is attached instead.
    val rddWithDummyKey: RDD[Product2[Int, ColumnarBatch]] = newPartitioning match {
      case RangePartitioning(sortingExpressions, _) =>
        rdd.mapPartitionsWithIndexInternal(
          (_, cbIter) => {
            val partitionKeyExtractor: InternalRow => Any = {
              val projection =
                UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
              row => projection(row)
            }
            computeAndAddPartitionId(cbIter, partitionKeyExtractor)
          },
          isOrderSensitive = isOrderSensitive
        )
      case _ =>
        rdd.mapPartitionsWithIndexInternal(
          (_, cbIter) => cbIter.map(cb => (0, cb)),
          isOrderSensitive = isOrderSensitive)
    }

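    // PartitionIdPassThrough returns the record key as the partition id; with the dummy key it
    // effectively only conveys numPartitions, since real ids are computed on the native side.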
    val dependency =
      new ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch](
        rddWithDummyKey,
        new PartitionIdPassThrough(newPartitioning.numPartitions),
        serializer,
        shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics),
        nativePartitioning = nativePartitioning,
        metrics = metrics,
        isSort = isSort
      )

    dependency
  }
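
A hypothetical call-site sketch (not part of ExecUtil.scala) of how a columnar exchange operator might invoke this method. It assumes genShuffleDependency is exposed on an ExecUtil object, as the file path suggests; every value passed in below is a placeholder for what the caller would already have.

import org.apache.spark.ShuffleDependency
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.utils.ExecUtil
import org.apache.spark.sql.vectorized.ColumnarBatch

object GenShuffleDependencyExample {
  // All parameters are placeholders for inputs the real exchange operator already has.
  def buildDependency(
      inputRdd: RDD[ColumnarBatch],
      childOutput: Seq[Attribute],
      targetPartitioning: Partitioning,
      serializer: Serializer,
      writeMetrics: Map[String, SQLMetric],
      allMetrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] =
    ExecUtil.genShuffleDependency(
      inputRdd, // RDD[ColumnarBatch] produced by the child operator
      childOutput, // the child plan's output attributes
      targetPartitioning, // e.g. HashPartitioning(keys, numPartitions)
      serializer, // columnar batch serializer
      writeMetrics, // shuffle write metrics
      allMetrics, // must contain a "numPartitions" SQLMetric
      isSort = false // whether the sort-based columnar shuffle writer is requested
    )
}

The returned dependency carries the native partitioning description alongside the usual Spark shuffle metadata; the read side wraps it into a shuffled columnar RDD and, as noted in the comments above, never looks at the dummy integer key.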