in spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala [290:438]
def prepareJVMShuffleDependency(
    rdd: RDD[InternalRow],
    outputAttributes: Seq[Attribute],
    newPartitioning: Partitioning,
    serializer: Serializer,
    writeMetrics: Map[String, SQLMetric]): ShuffleDependency[Int, InternalRow, InternalRow] = {
  val part: Partitioner = newPartitioning match {
    case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
    case HashPartitioning(_, n) =>
      // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
      // `HashPartitioning.partitionIdExpression` to produce the partitioning key.
      new PartitionIdPassthrough(n)
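      // For illustration: in Spark's sources, `HashPartitioning.partitionIdExpression` is
      // `Pmod(new Murmur3Hash(expressions), Literal(numPartitions))`, so evaluating it on a
      // row already yields a value in [0, n), and the partitioner can pass it through as-is.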
    case RangePartitioning(sortingExpressions, numPartitions) =>
      // Extract only the fields used for sorting, to avoid collecting large fields that do
      // not affect the sorting result when RangePartitioner decides the partition bounds.
      val rddForSampling = rdd.mapPartitionsInternal { iter =>
        val projection =
          UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
        val mutablePair = new MutablePair[InternalRow, Null]()
        // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
        // partition bounds. To get accurate samples, we need to copy the mutable keys.
        iter.map(row => mutablePair.update(projection(row).copy(), null))
      }
      // Construct the ordering on the extracted sort key.
      val orderingAttributes = sortingExpressions.zipWithIndex.map { case (ord, i) =>
        ord.copy(child = BoundReference(i, ord.dataType, ord.nullable))
      }
      implicit val ordering = new LazilyGeneratedOrdering(orderingAttributes)
      new RangePartitioner(
        numPartitions,
        rddForSampling,
        ascending = true,
        samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
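      // Note that constructing the RangePartitioner is eager: it runs a sampling job over
      // rddForSampling (reservoir sampling within each partition) to pick numPartitions - 1
      // upper bounds, and its getPartition later searches those bounds for each sort key.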
    case SinglePartition => new ConstantPartitioner
    case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
    // TODO: Handle BroadcastPartitioning.
  }
  def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match {
    case RoundRobinPartitioning(numPartitions) =>
      // Distributes elements evenly across output partitions, starting from a random partition.
      // The nextInt(numPartitions) implementation has a special case when the bound is a power
      // of 2: it basically takes several of the highest bits from the initial seed, with only
      // minimal scrambling. Due to the deterministic seed, using the generator only once, and
      // the lack of scrambling, the position values for power-of-two numPartitions would
      // otherwise end up being almost the same regardless of the partition index. Substantially
      // scrambling the seed by hashing helps. Refer to SPARK-21782 for more details.
      val partitionId = TaskContext.get().partitionId()
      var position = new XORShiftRandom(partitionId).nextInt(numPartitions)
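      // XORShiftRandom hashes its initial seed (with MurmurHash3) before using it, which
      // provides exactly that scrambling, so nearby partition IDs map to well-spread start
      // positions.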
      (_: InternalRow) => {
        // The HashPartitioner will handle the `mod` by the number of partitions.
        position += 1
        position
      }
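      // For example, with numPartitions = 8 and a start position of 5, successive rows produce
      // keys 6, 7, 8, 9, ..., which the HashPartitioner above maps to partitions 6, 7, 0, 1, ...
      // via a non-negative mod, i.e. a round-robin walk starting from a random partition.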
    case h: HashPartitioning =>
      val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
      row => projection(row).getInt(0)
    case RangePartitioning(sortingExpressions, _) =>
      val projection =
        UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
      row => projection(row)
    case SinglePartition => identity
    case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
  }
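  // To summarize, the extractor returns whatever `part.getPartition` expects: an incrementing
  // Int that HashPartitioner mods into range (round-robin), an Int that is already a partition
  // ID (hash), the projected sort key row (range), or a value ConstantPartitioner ignores.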
  val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
    newPartitioning.numPartitions > 1
  val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
    // [SPARK-23207] We have to make sure the generated RoundRobinPartitioning is deterministic,
    // otherwise a retried task may output different rows and thus lead to data loss.
    //
    // Currently we follow the most straightforward approach, which is to perform a local sort
    // before partitioning.
    //
    // Note that we don't perform the local sort if the new partitioning has only 1 partition,
    // since in that case all output rows go to the same partition anyway.
    val newRdd = if (isRoundRobin && SQLConf.get.sortBeforeRepartition) {
      rdd.mapPartitionsInternal { iter =>
        val recordComparatorSupplier = new Supplier[RecordComparator] {
          override def get: RecordComparator = new RecordBinaryComparator()
        }
        // The comparator for comparing the row hashcode, which should always be an integer.
        val prefixComparator = PrefixComparators.LONG
        // The prefix computer generates the row hashcode as the prefix, which decreases the
        // probability that prefixes are equal when input rows draw their column values from
        // a limited range.
        val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
          private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
          override def computePrefix(
              row: InternalRow): UnsafeExternalRowSorter.PrefixComputer.Prefix = {
            // The hashcode generated from the binary form of an [[UnsafeRow]] should not be null.
            result.isNull = false
            result.value = row.hashCode()
            result
          }
        }
        val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
        val sorter = UnsafeExternalRowSorter.createWithRecordComparator(
          fromAttributes(outputAttributes),
          recordComparatorSupplier,
          prefixComparator,
          prefixComputer,
          pageSize,
          // We are comparing binary here, which does not support radix sort.
          // See more details in SPARK-28699.
          false)
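        // The sort key is the rows' binary content (hashcode prefix, then a byte-by-byte
        // comparison via RecordBinaryComparator), so the output order depends only on the row
        // values: assuming a retried task recomputes the same set of input rows, it emits them
        // in the same order, making the round-robin position assignment below deterministic.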
        sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
      }
    } else {
      rdd
    }
    // The round-robin function is order-sensitive if we don't sort the input.
    val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition
    if (CometShuffleExchangeExec.needToCopyObjectsBeforeShuffle(part)) {
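      // Rows must be copied on this path because the shuffle machinery may buffer them as Java
      // objects rather than serializing them immediately, while the input iterator typically
      // reuses a single mutable UnsafeRow instance across calls to next().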
      newRdd.mapPartitionsWithIndexInternal(
        (_, iter) => {
          val getPartitionKey = getPartitionKeyExtractor()
          iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
        },
        isOrderSensitive = isOrderSensitive)
    } else {
      newRdd.mapPartitionsWithIndexInternal(
        (_, iter) => {
          val getPartitionKey = getPartitionKeyExtractor()
          val mutablePair = new MutablePair[Int, InternalRow]()
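          // Reusing a single MutablePair avoids allocating a tuple per row; this is safe here
          // because each pair is consumed by the shuffle writer before the next row overwrites
          // it.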
          iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
        },
        isOrderSensitive = isOrderSensitive)
    }
  }
  // Now, we manually create a ShuffleDependency. Because the pairs in rddWithPartitionIds are
  // in the form of (partitionId, row) and every partitionId is in the expected range
  // [0, part.numPartitions - 1], the partitioner of this RDD is a PartitionIdPassthrough.
  val dependency =
    new CometShuffleDependency[Int, InternalRow, InternalRow](
      rddWithPartitionIds,
      new PartitionIdPassthrough(part.numPartitions),
      serializer,
      shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics),
      shuffleType = CometColumnarShuffle,
      schema = Some(fromAttributes(outputAttributes)),
      decodeTime = writeMetrics("decode_time"))
  dependency
}