private def repartitionAndSort()

in maven-projects/spark/graphar/src/main/scala/org/apache/graphar/writer/EdgeWriter.scala [47:200]


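  /**
   * Repartition the edge DataFrame into edge chunks and sort within partitions.
   *
   * @param spark the Spark session
   * @param edgeDf the edge DataFrame to repartition
   * @param edgeInfo the edge info describing the chunk sizes
   * @param adjListType the adjacency list type to generate chunks for
   * @param vertexNumOfPrimaryVertexType the vertex count of the primary vertex type
   * @return a tuple of (repartitioned edge DataFrame,
   *         offset chunk DataFrames keyed by vertex chunk index
   *         (empty for unordered adjacency list types),
   *         cumulative edge chunk counts per vertex chunk,
   *         edge count per vertex chunk)
   */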
  private def repartitionAndSort(
      spark: SparkSession,
      edgeDf: DataFrame,
      edgeInfo: EdgeInfo,
      adjListType: AdjListType.Value,
      vertexNumOfPrimaryVertexType: Long
  ): (DataFrame, ParSeq[(Int, DataFrame)], Array[Long], Map[Long, Int]) = {
    val edgeSchema = edgeDf.schema
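    // order by the source index for *_by_source types, otherwise by the destination index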
    val colName = if (
      adjListType == AdjListType.ordered_by_source || adjListType == AdjListType.unordered_by_source
    ) GeneralParams.srcIndexCol
    else GeneralParams.dstIndexCol
    val colIndex = edgeSchema.fieldIndex(colName)
    val vertexChunkSize: Long = if (
      adjListType == AdjListType.ordered_by_source || adjListType == AdjListType.unordered_by_source
    ) edgeInfo.getSrc_chunk_size()
    else edgeInfo.getDst_chunk_size()
    val edgeChunkSize: Long = edgeInfo.getChunk_size()
    val vertexChunkNum: Int =
      ((vertexNumOfPrimaryVertexType + vertexChunkSize - 1) / vertexChunkSize).toInt // ceil

    // sort by the primary key column and generate continuous edge ids for the edge records
    val sortedDfRDD = edgeDf.sort(colName).rdd
    sortedDfRDD.persist(GeneralParams.defaultStorageLevel)
    // generate a continuous edge id (eid) for every edge
    val partitionCounts = sortedDfRDD
      .mapPartitionsWithIndex(
        (i, ps) => Array((i, ps.size)).iterator,
        preservesPartitioning = true
      )
      .collectAsMap()
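    // exclusive prefix sum: map each partition index to the number of rows in all earlier partitions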
    val aggregatedPartitionCounts = SortedMap(partitionCounts.toSeq: _*)
      .foldLeft((0L, Map.empty[Int, Long])) { case ((total, map), (i, c)) =>
        (total + c, map + (i -> total))
      }
      ._2
    val broadcastedPartitionCounts =
      spark.sparkContext.broadcast(aggregatedPartitionCounts)
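    // each row's eid = its partition's global start offset + its local index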
    val rddWithEid = sortedDfRDD.mapPartitionsWithIndex((i, ps) => {
      val start = broadcastedPartitionCounts.value(i)
      for { (row, j) <- ps.zipWithIndex } yield (start + j, row)
    })
    rddWithEid.persist(GeneralParams.defaultStorageLevel)

    // Construct the partitioner for edge chunks
    // first, get the edge count of every vertex chunk
    val edgeNumOfVertexChunks = sortedDfRDD
      .mapPartitions(iterator => {
        iterator.map(row =>
          (row(colIndex).asInstanceOf[Long] / vertexChunkSize, 1)
        )
      })
      .reduceByKey(_ + _)
      .collectAsMap()
    // Mapping: vertex_chunk_index -> edge num of the vertex chunk
    val edgeNumMutableMap =
      collection.mutable.Map(edgeNumOfVertexChunks.toSeq: _*)
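    // ensure every vertex chunk has an entry, even if it contains no edges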
    for (i <- 0L until vertexChunkNum.toLong) {
      if (!edgeNumMutableMap.contains(i)) {
        edgeNumMutableMap(i) = 0
      }
    }
    sortedDfRDD.unpersist() // unpersist the sortedDfRDD

    val eidBeginOfVertexChunks =
      new Array[Long](vertexChunkNum + 1) // first eid of each vertex chunk
    val aggEdgeChunkNumOfVertexChunks =
      new Array[Long](vertexChunkNum + 1) // cumulative edge chunk count before each vertex chunk
    var eid: Long = 0
    var edgeChunkIndex: Long = 0
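    // accumulate, per vertex chunk, its first eid and the edge chunk count before it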
    for (i <- 0 until vertexChunkNum) {
      eidBeginOfVertexChunks(i) = eid
      aggEdgeChunkNumOfVertexChunks(i) = edgeChunkIndex
      eid = eid + edgeNumMutableMap(i)
      edgeChunkIndex = edgeChunkIndex + (edgeNumMutableMap(
        i
      ) + edgeChunkSize - 1) / edgeChunkSize
    }
    eidBeginOfVertexChunks(vertexChunkNum) = eid
    aggEdgeChunkNumOfVertexChunks(vertexChunkNum) = edgeChunkIndex
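    // e.g. with edgeChunkSize = 4 and per-vertex-chunk edge counts [5, 0, 3]:
    // eidBeginOfVertexChunks = [0, 5, 5, 8], aggEdgeChunkNumOfVertexChunks = [0, 2, 2, 3]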

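    // the total edge chunk count is the number of output partitions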
    val partitionNum = edgeChunkIndex.toInt
    val partitioner = new EdgeChunkPartitioner(
      partitionNum,
      eidBeginOfVertexChunks,
      aggEdgeChunkNumOfVertexChunks,
      edgeChunkSize.toInt
    )

    // repartition edge DataFrame and sort within partitions
    val partitionRDD =
      rddWithEid.repartitionAndSortWithinPartitions(partitioner).values
    val partitionEdgeDf = spark.createDataFrame(partitionRDD, edgeSchema)
    rddWithEid.unpersist() // unpersist the rddWithEid
    partitionEdgeDf.persist(GeneralParams.defaultStorageLevel)

    // generate offset DataFrames
    if (
      adjListType == AdjListType.ordered_by_source || adjListType == AdjListType.ordered_by_dest
    ) {
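      // count the edges incident to each vertex id of the primary type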
      val edgeCountsByPrimaryKey = partitionRDD
        .mapPartitions(iterator => {
          iterator.map(row => (row(colIndex).asInstanceOf[Long], 1))
        })
        .reduceByKey(_ + _)
      edgeCountsByPrimaryKey.persist(GeneralParams.defaultStorageLevel)
      val offsetDfSchema = StructType(
        Seq(StructField(GeneralParams.offsetCol, IntegerType))
      )
      val offsetDfArray: ParSeq[(Int, DataFrame)] =
        (0 until vertexChunkNum).par.map { i =>
          {
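            // build the offset chunk for vertex chunk i: shift vertex ids
            // into this chunk's local key range [1, vertexChunkSize]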
            val filterRDD = edgeCountsByPrimaryKey
              .filter(v => v._1 / vertexChunkSize == i)
              .map { case (k, v) => (k - i * vertexChunkSize + 1, v) }
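            // seed keys 0..vertexChunkSize with zero so vertices without
            // edges still get an offset entry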
            val initRDD = spark.sparkContext
              .range(0L, vertexChunkSize + 1)
              .map(key => (key, 0))
            val unionRDD = spark.sparkContext
              .union(filterRDD, initRDD)
              .reduceByKey(_ + _)
              .sortByKey(numPartitions = 1)
            val offsetRDD = unionRDD
              .mapPartitionsWithIndex((_, ps) => {
                // running prefix sum turns per-vertex edge counts into offsets
                var sum = 0
                for ((k, count) <- ps) yield {
                  sum += count
                  (k, sum)
                }
              })
              .map { case (k, v) => Row(v) }
            val offsetChunk = spark.createDataFrame(offsetRDD, offsetDfSchema)
            offsetChunk.persist(GeneralParams.defaultStorageLevel)
            (i, offsetChunk)
          }
        }
      edgeCountsByPrimaryKey.unpersist() // unpersist the edgeCountsByPrimaryKey
      return (
        partitionEdgeDf,
        offsetDfArray,
        aggEdgeChunkNumOfVertexChunks,
        edgeNumMutableMap.toMap
      )
    }
    val offsetDfArray = ParSeq.empty[(Int, DataFrame)]
    return (
      partitionEdgeDf,
      offsetDfArray,
      aggEdgeChunkNumOfVertexChunks,
      edgeNumMutableMap.toMap
    )
  }
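
For context, the EdgeChunkPartitioner constructed above is not part of this excerpt. The sketch below illustrates the contract the method relies on, assuming the partitioner keys are the global eids generated earlier; the class name and the binary search are illustrative assumptions, not the actual GraphAr implementation.

// A minimal sketch of the partitioner contract assumed above.
import org.apache.spark.Partitioner

class EdgeChunkPartitionerSketch(
    numParts: Int,
    eidBeginOfVertexChunks: Array[Long], // length vertexChunkNum + 1
    aggEdgeChunkNumOfVertexChunks: Array[Long], // length vertexChunkNum + 1
    edgeChunkSize: Int
) extends Partitioner {

  override def numPartitions: Int = numParts

  // key is the global edge id (eid) attached to each row
  override def getPartition(key: Any): Int = {
    val eid = key.asInstanceOf[Long]
    // find the vertex chunk whose eid range contains this eid
    var lo = 0
    var hi = eidBeginOfVertexChunks.length - 2
    while (lo < hi) {
      val mid = (lo + hi + 1) / 2
      if (eidBeginOfVertexChunks(mid) <= eid) lo = mid else hi = mid - 1
    }
    // partition id = edge chunks before this vertex chunk
    // + this eid's edge chunk offset within the vertex chunk
    (aggEdgeChunkNumOfVertexChunks(lo) +
      (eid - eidBeginOfVertexChunks(lo)) / edgeChunkSize).toInt
  }
}

Because the keys passed to repartitionAndSortWithinPartitions are the eids themselves, sorting within each partition by key preserves the global order established by the earlier sort.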