in s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/udfs/WalLogUDF.scala [125:209]
/**
 * Assigns every record of `ds` a 1-based rank within its `K1` group, where records of a group
 * are ordered by the implicit `Ordering[(K1, K2)]` (similar to SQL `row_number()` partitioned by
 * `K1` and ordered by `(K1, K2)`).
 *
 * The data is range-partitioned and sorted on the full `(K1, K2)` key, so one `K1` group may
 * span several consecutive partitions. Ranks are first computed locally per partition. Each
 * local rank is then shifted by the number of records with the same `K1` that live in earlier
 * partitions. Those start offsets are collected to the driver and broadcast.
 *
 * Records with identical `(K1, K2)` keys still receive distinct, consecutive ranks. Their
 * relative order is not fixed by this method.
 *
 * @param ds                           records keyed by `(K1, K2)`
 * @param numOfPartitions              number of range partitions; defaults to the partition
 *                                     count of `ds.rdd`
 * @param samplePointsPerPartitionHint sampling hint passed to `RangePartitioner`; defaults to 20
 * @param ordering                     ordering of the composite key, used both for range
 *                                     partitioning and for the in-partition sort
 * @tparam K1 grouping key (ranks restart at 1 for every distinct `K1`)
 * @tparam K2 secondary key that orders records inside a `K1` group
 * @tparam V  payload, carried through unchanged
 * @return an RDD of `(rank, (key1, key2), value)`
 */
def appendRank[K1: ClassTag, K2: ClassTag, V: ClassTag](ds: Dataset[((K1, K2), V)],
                                                        numOfPartitions: Option[Int] = None,
                                                        samplePointsPerPartitionHint: Option[Int] = None)
                                                       (implicit ordering: Ordering[(K1, K2)]) = {
  import org.apache.spark.RangePartitioner

  val rdd = ds.rdd
  val partitioner = new RangePartitioner(
    numOfPartitions.getOrElse(rdd.partitions.length),
    rdd,
    ascending = true,
    samplePointsPerPartitionHint = samplePointsPerPartitionHint.getOrElse(20)
  )
  // Sorted within each partition, and (ascending range partitioning) no key in partition i sorts
  // after a key in partition i + 1, so each K1 group occupies consecutive partitions.
  val sorted = rdd.repartitionAndSortWithinPartitions(partitioner)

  // Sizes of the contiguous K1 runs in one sorted partition, emitted as (key1, (partitionIndex, runSize)).
  def groupSizesInPartition(idx: Int, iter: Iterator[((K1, K2), V)]): Iterator[(K1, (Int, Long))] = {
    val runs = mutable.ArrayBuffer.empty[(K1, (Int, Long))]
    // An explicit flag (not a null sentinel) so a null key is never mistaken for "no previous key".
    var hasPrev = false
    var prevKey1: K1 = null.asInstanceOf[K1]
    var size = 0L
    iter.foreach { case ((key1, _), _) =>
      if (hasPrev && prevKey1 != key1) {
        runs += ((prevKey1, (idx, size)))
        size = 0L
      }
      hasPrev = true
      prevKey1 = key1
      size += 1L
    }
    if (hasPrev) runs += ((prevKey1, (idx, size)))
    runs.iterator
  }

  // For each (key1, partition) whose K1 group began in an earlier partition: the number of key1
  // records in all earlier partitions. The first partition of every group starts at offset 0 and
  // is omitted; it is looked up with a default of 0 below.
  val startOffsets: Map[(K1, Int), Long] = sorted
    .mapPartitionsWithIndex(groupSizesInPartition)
    .groupByKey()
    .flatMap { case (key1, partitionSizes) =>
      val ordered = partitionSizes.toSeq.sortBy(_._1)
      val cumulative = ordered.map(_._2).scanLeft(0L)(_ + _)
      ordered.map(_._1).zip(cumulative).drop(1).map { case (partition, offset) =>
        ((key1, partition), offset)
      }
    }
    .collect()
    .toMap
  // Build the Map once on the driver instead of converting an array in every task.
  val startOffsetsBCast = ds.sparkSession.sparkContext.broadcast(startOffsets)

  sorted.mapPartitionsWithIndex { (partition, iter) =>
    val offsets = startOffsetsBCast.value
    var hasPrev = false
    var prevKey1: K1 = null.asInstanceOf[K1]
    var rankInPartition = 0L
    iter.map { case ((key1, key2), value) =>
      // Restart at 1 whenever a new K1 run begins inside this partition.
      rankInPartition = if (hasPrev && prevKey1 == key1) rankInPartition + 1L else 1L
      hasPrev = true
      prevKey1 = key1
      (offsets.getOrElse((key1, partition), 0L) + rankInPartition, (key1, key2), value)
    }
  }
}