in spark-connector/common/src/main/scala/org/apache/spark/sql/odps/execution/exchange/OdpsShuffleExchangeExec.scala [196:347]
/**
 * Builds the [[ShuffleDependency]] that redistributes `rdd` according to `newPartitioning`.
 *
 * Every input row is paired on the map side with its target partition id, derived from
 * `newPartitioning`. The resulting `(partitionId, row)` pairs are then shuffled through a
 * [[PartitionIdPassthrough]], so the id computed here is used unchanged as the reduce partition.
 *
 * @param rdd              the rows to shuffle
 * @param outputAttributes attributes of `rdd`'s rows; partitioning and sorting expressions are
 *                         bound against them
 * @param newPartitioning  requested output partitioning: round-robin, Spark hash, ODPS hash,
 *                         range or single partitioning
 * @param serializer       serializer used for the shuffled rows
 * @param writeMetrics     metrics passed to the shuffle write processor
 * @return a shuffle dependency whose keys are target partition ids
 * @throws IllegalStateException if `newPartitioning` is not one of the supported kinds
 */
def prepareShuffleDependency(
    rdd: RDD[InternalRow],
    outputAttributes: Seq[Attribute],
    newPartitioning: Partitioning,
    serializer: Serializer,
    writeMetrics: Map[String, SQLMetric])
  : ShuffleDependency[Int, InternalRow, InternalRow] = {
  val part: Partitioner = newPartitioning match {
    case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
    case HashPartitioning(_, n) =>
      new Partitioner {
        override def numPartitions: Int = n
        // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
        // `HashPartitioning.partitionIdExpression` to produce the partitioning key.
        override def getPartition(key: Any): Int = key.asInstanceOf[Int]
      }
    case OdpsHashPartitioning(_, n) =>
      new Partitioner {
        override def numPartitions: Int = n
        // For OdpsHashPartitioning, the partitioning key is already a valid partition ID,
        // as we use `OdpsHashPartitioning.partitionIdExpression` to produce the partitioning key.
        override def getPartition(key: Any): Int = key.asInstanceOf[Int]
      }
    case RangePartitioning(sortingExpressions, numPartitions) =>
      // Extract only the fields used for sorting, to avoid collecting large fields that do not
      // affect the sort result when RangePartitioner decides the partition bounds.
      val rddForSampling = rdd.mapPartitionsInternal { iter =>
        val projection =
          UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
        val mutablePair = new MutablePair[InternalRow, Null]()
        // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
        // partition bounds. To get accurate samples, the mutable projected keys must be copied.
        iter.map(row => mutablePair.update(projection(row).copy(), null))
      }
      // Construct the ordering on the extracted sort key.
      val orderingAttributes = sortingExpressions.zipWithIndex.map { case (ord, i) =>
        ord.copy(child = BoundReference(i, ord.dataType, ord.nullable))
      }
      implicit val ordering = new LazilyGeneratedOrdering(orderingAttributes)
      new RangePartitioner(
        numPartitions,
        rddForSampling,
        ascending = true,
        samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
    case SinglePartition =>
      new Partitioner {
        override def numPartitions: Int = 1
        override def getPartition(key: Any): Int = 0
      }
    // TODO: Handle BroadcastPartitioning.
    case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
  }

  // Returns a function mapping a row to the key that `part` turns into a partition id.
  // It is invoked once per input partition, inside the map-side tasks.
  def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match {
    case RoundRobinPartitioning(numPartitions) =>
      // Distributes elements evenly across output partitions, starting from a random partition.
      // The seed is scrambled with byteswap32: for a power-of-two bound, Random.nextInt keeps
      // only the top bits of the first generated value, which barely differ between consecutive
      // small seeds. Seeding with the raw partition id would make almost every task start at
      // the same output partition and skew the output (see SPARK-21782). The seed still depends
      // only on the partition id, so a retried task produces the same assignment.
      val seed = scala.util.hashing.byteswap32(TaskContext.get().partitionId())
      var position = new Random(seed).nextInt(numPartitions)
      (row: InternalRow) => {
        // The HashPartitioner will handle the `mod` by the number of partitions.
        position += 1
        position
      }
    case h: HashPartitioning =>
      val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
      row => projection(row).getInt(0)
    case h: OdpsHashPartitioning =>
      val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
      row => projection(row).getInt(0)
    case RangePartitioning(sortingExpressions, _) =>
      val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
      row => projection(row)
    case SinglePartition => identity
    case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
  }

  val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
    newPartitioning.numPartitions > 1
  // Read once so the local-sort decision and `isOrderSensitive` come from the same value.
  val sortBeforeRepartition = SQLConf.get.sortBeforeRepartition

  val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
    // [SPARK-23207] The generated RoundRobinPartitioning must be deterministic; otherwise a
    // retried task may output different rows and thus lead to data loss.
    //
    // Currently we follow the most straightforward approach: perform a local sort before
    // partitioning.
    //
    // No local sort is done if the new partitioning has only 1 partition, because in that
    // case all output rows go to the same partition.
    val newRdd = if (isRoundRobin && sortBeforeRepartition) {
      rdd.mapPartitionsInternal { iter =>
        val recordComparatorSupplier = new Supplier[RecordComparator] {
          override def get: RecordComparator = new RecordBinaryComparator()
        }
        // The prefix holds the row's hashCode (an Int) in the prefix value, which is compared
        // with the LONG prefix comparator.
        val prefixComparator = PrefixComparators.LONG
        // Using the row hashcode as the prefix decreases the probability that prefixes are
        // equal when input rows take column values from a limited range.
        val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
          private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
          override def computePrefix(row: InternalRow):
              UnsafeExternalRowSorter.PrefixComputer.Prefix = {
            // The hashcode generated from the binary form of an [[UnsafeRow]] is never null.
            result.isNull = false
            result.value = row.hashCode()
            result
          }
        }
        val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
        val sorter = UnsafeExternalRowSorter.createWithRecordComparator(
          StructType.fromAttributes(outputAttributes),
          recordComparatorSupplier,
          prefixComparator,
          prefixComputer,
          pageSize,
          // We are comparing binary here, which does not support radix sort.
          // See more details in SPARK-28699.
          false)
        sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
      }
    } else {
      rdd
    }
    // The round-robin function is order sensitive if the input is not sorted first.
    val isOrderSensitive = isRoundRobin && !sortBeforeRepartition
    if (needToCopyObjectsBeforeShuffle(part)) {
      newRdd.mapPartitionsWithIndexInternal((_, iter) => {
        val getPartitionKey = getPartitionKeyExtractor()
        iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
      }, isOrderSensitive = isOrderSensitive)
    } else {
      newRdd.mapPartitionsWithIndexInternal((_, iter) => {
        val getPartitionKey = getPartitionKeyExtractor()
        val mutablePair = new MutablePair[Int, InternalRow]()
        iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
      }, isOrderSensitive = isOrderSensitive)
    }
  }

  // Create the ShuffleDependency manually. The pairs in rddWithPartitionIds have the form
  // (partitionId, row), and every partitionId is already in [0, part.numPartitions - 1],
  // so a PartitionIdPassthrough partitioner simply forwards the key as the partition.
  val dependency =
    new ShuffleDependency[Int, InternalRow, InternalRow](
      rddWithPartitionIds,
      new PartitionIdPassthrough(part.numPartitions),
      serializer,
      shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics))
  dependency
}