def appendRank[K1: ClassTag, K2: ClassTag, V: ClassTag]()

in s2jobs/src/main/scala/org/apache/s2graph/s2jobs/wal/udfs/WalLogUDF.scala [125:209]


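  /**
   * Appends a global, per-K1 rank to every ((K1, K2), V) record. Records are
   * range-partitioned and sorted by (K1, K2), ranked within each partition, and
   * the per-partition ranks are then shifted so that each K1 group gets a
   * contiguous, 1-based ranking ordered by K2. Returns (rank, (K1, K2), V) tuples.
   */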
  def appendRank[K1: ClassTag, K2: ClassTag, V: ClassTag](ds: Dataset[((K1, K2), V)],
                                                          numOfPartitions: Option[Int] = None,
                                                          samplePointsPerPartitionHint: Option[Int] = None)(implicit ordering: Ordering[(K1, K2)]) = {
    import org.apache.spark.RangePartitioner
    val rdd = ds.rdd

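    // Range-partition and sort by (K1, K2) so every K1 group is contiguous in the
    // global sort order (a single K1 may span adjacent partitions).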
    val partitioner = new RangePartitioner(numOfPartitions.getOrElse(rdd.partitions.size),
      rdd,
      true,
      samplePointsPerPartitionHint = samplePointsPerPartitionHint.getOrElse(20)
    )

    val sorted = rdd.repartitionAndSortWithinPartitions(partitioner)

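    // 1-based rank of each record within its K1 group, computed independently per partition.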
    def rank(idx: Int, iter: Iterator[((K1, K2), V)]) = {
      var curOffset = 1L
      // Option instead of a null/default sentinel, so a first key that equals the
      // type's default value (e.g. 0 for a numeric K1) is still ranked 1.
      var curK1 = Option.empty[K1]

      iter.map { case ((key1, key2), value) =>
        val newOffset = if (curK1.contains(key1)) curOffset + 1L else 1L
        curOffset = newOffset
        curK1 = Some(key1)
        (idx, newOffset, key1, key2, value)
      }
    }

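    // Per partition, count how many records belong to each K1 (groups are contiguous after the sort).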
    def getOffset(idx: Int, iter: Iterator[((K1, K2), V)]) = {
      val buffer = mutable.Map.empty[K1, (Int, Long)]
      if (!iter.hasNext) buffer.iterator
      else {
        val ((firstK1, _), _) = iter.next()
        var prevKey1: K1 = firstK1
        var size = 1L
        iter.foreach { case ((k1, _), _) =>
          if (prevKey1 != k1) {
            buffer += prevKey1 -> (idx, size)
            prevKey1 = k1
            size = 0L
          }
          size += 1L
        }
        if (size > 0) buffer += prevKey1 -> (idx, size)
        buffer.iterator
      }
    }

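    // partRanks: within-partition ranks; offsets: for each (K1, partition) except the first
    // partition containing that K1, the number of records with the same K1 in earlier
    // partitions (a prefix sum). Pairs missing from this table implicitly start at offset 0.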
    val partRanks = sorted.mapPartitionsWithIndex(rank)
    val _offsets = sorted.mapPartitionsWithIndex(getOffset)
    val offsets = _offsets.groupBy(_._1).flatMap { case (k1, partitionWithSize) =>
      val sizesByPartition = partitionWithSize.toSeq.map(_._2).sortBy(_._1)
      var sum = sizesByPartition.head._2
      val startOffsets = sizesByPartition.tail.map { case (partition, size) =>
        val offset = (partition, sum)
        sum += size
        offset
      }
      startOffsets.map { case (partition, offset) =>
        (k1, partition) -> offset
      }
    }.collect()

    println(offsets.mkString(", "))

    val offsetsBCast = ds.sparkSession.sparkContext.broadcast(offsets)

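    // Global rank = within-partition rank + broadcast start offset for (K1, partition).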
    def adjust(iter: Iterator[(Int, Long, K1, K2, V)], startOffsets: Map[(K1, Int), Long]) = {
      iter.map { case (partition, rankInPartition, key1, key2, value) =>
        val startOffset = startOffsets.getOrElse((key1, partition), 0L)
        val rank = startOffset + rankInPartition

        (partition, rankInPartition, rank, (key1, key2), value)
      }
    }

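    // Shift the per-partition ranks by the broadcast offsets and keep only (rank, (K1, K2), V).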
    val withRanks = partRanks
      .mapPartitions { iter =>
        val startOffsets = offsetsBCast.value.toMap
        adjust(iter, startOffsets)
      }
      .map { case (_, _, rank, (key1, key2), value) =>
        (rank, (key1, key2), value)
      }

    withRanks
  }
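
A minimal usage sketch (not part of WalLogUDF.scala): it assumes appendRank is exposed on a WalLogUDF object, as the file name suggests, and the SparkSession setup and sample data below are invented for illustration. Here K1 is a user id, K2 a timestamp, and the result ranks each user's events by time.

  import org.apache.spark.sql.SparkSession
  import org.apache.s2graph.s2jobs.wal.udfs.WalLogUDF

  object AppendRankExample {
    def main(args: Array[String]): Unit = {
      val spark = SparkSession.builder().master("local[2]").appName("appendRankExample").getOrCreate()
      import spark.implicits._

      // ((userId, timestamp), event): rank each user's events by timestamp.
      val ds = Seq(
        (("u1", 10L), "view"),
        (("u2", 5L), "view"),
        (("u1", 20L), "click"),
        (("u1", 30L), "purchase")
      ).toDS()

      // Hypothetical call site: assumes appendRank is a member of object WalLogUDF.
      // Ordering[(String, Long)] and the ClassTags are resolved implicitly.
      val ranked = WalLogUDF.appendRank(ds, numOfPartitions = Some(2))

      // One tuple per record, e.g. (1,(u1,10),view), (2,(u1,20),click), (3,(u1,30),purchase), (1,(u2,5),view)
      ranked.collect().foreach(println)

      spark.stop()
    }
  }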