def internalWrite()

in backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala [132:249]


  def internalWrite(records: Iterator[Product2[K, V]]): Unit = {
    if (!records.hasNext) {
      handleEmptyInput()
      return
    }

    val dataTmp = Utils.tempFileWith(shuffleBlockResolver.getDataFile(dep.shuffleId, mapId))

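    // Push each incoming ColumnarBatch through the native shuffle writer,
    // creating the writer lazily on the first non-empty batch.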
    while (records.hasNext) {
      val cb = records.next()._2.asInstanceOf[ColumnarBatch]
      if (cb.numRows == 0 || cb.numCols == 0) {
        logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} cols")
      } else {
        val rows = cb.numRows()
        val handle = ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, cb)
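        // Lazily create the native shuffle writer on the first non-empty batch;
        // -1L marks a not-yet-created writer.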
        if (nativeShuffleWriter == -1L) {
          nativeShuffleWriter = jniWrapper.make(
            dep.nativePartitioning.getShortName,
            dep.nativePartitioning.getNumPartitions,
            nativeBufferSize,
            nativeMergeBufferSize,
            nativeMergeThreshold,
            compressionCodec.orNull,
            compressionCodecBackend.orNull,
            compressionLevel,
            compressionBufferSize,
            conf.get(SHUFFLE_DISK_WRITE_BUFFER_SIZE).toInt,
            bufferCompressThreshold,
            GlutenConfig.get.columnarShuffleCompressionMode,
            conf.get(SHUFFLE_SORT_INIT_BUFFER_SIZE).toInt,
            conf.get(SHUFFLE_SORT_USE_RADIXSORT),
            dataTmp.getAbsolutePath,
            blockManager.subDirsPerLocalDir,
            localDirs,
            reallocThreshold,
            handle,
            taskContext.taskAttemptId(),
            GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, taskContext.partitionId),
            shuffleWriterType
          )
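          // Register a spiller so the memory manager can ask the native writer
          // to evict buffered partition data to disk under memory pressure.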
          runtime
            .memoryManager()
            .addSpiller(new Spiller() {
              override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = {
                if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
                  return 0L
                }
                logInfo(s"Gluten shuffle writer: Trying to spill $size bytes of data")
                // FIXME: pass true when this eviction is triggered by the writer itself
                val spilled =
                  jniWrapper.nativeEvict(nativeShuffleWriter, size, false)
                logInfo(s"Gluten shuffle writer: Spilled $spilled / $size bytes of data")
                spilled
              }
            })
        }
        val startTime = System.nanoTime()
        jniWrapper.write(nativeShuffleWriter, rows, handle, availableOffHeapPerTask())
        dep.metrics("shuffleWallTime").add(System.nanoTime() - startTime)
        dep.metrics("numInputRows").add(rows)
        dep.metrics("inputBatches").add(1)
        // This metric is important: AQE uses it to decide whether to apply EliminateLimit
        writeMetrics.incRecordsWritten(rows)
      }
      cb.close()
    }

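    // Every batch was empty, so the native writer was never created;
    // fall back to the empty-input path.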
    if (nativeShuffleWriter == -1L) {
      handleEmptyInput()
      return
    }

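    // Flush any remaining buffered data to the temp file and collect
    // per-partition lengths plus native-side timing statistics.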
    val startTime = System.nanoTime()
    assert(nativeShuffleWriter != -1L)
    splitResult = jniWrapper.stop(nativeShuffleWriter)
    closeShuffleWriter()
    dep.metrics("shuffleWallTime").add(System.nanoTime() - startTime)
    if (!isSort) {
      dep
        .metrics("splitTime")
        .add(
          dep.metrics("shuffleWallTime").value - splitResult.getTotalSpillTime -
            splitResult.getTotalWriteTime -
            splitResult.getTotalCompressTime)
    } else {
      dep.metrics("sortTime").add(splitResult.getSortTime)
      dep.metrics("c2rTime").add(splitResult.getC2RTime)
    }
    dep.metrics("spillTime").add(splitResult.getTotalSpillTime)
    dep.metrics("bytesSpilled").add(splitResult.getTotalBytesSpilled)
    dep.metrics("dataSize").add(splitResult.getRawPartitionLengths.sum)
    dep.metrics("compressTime").add(splitResult.getTotalCompressTime)
    dep.metrics("peakBytes").add(splitResult.getPeakBytes)
    writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten)
    writeMetrics.incWriteTime(splitResult.getTotalWriteTime + splitResult.getTotalSpillTime)
    taskContext.taskMetrics().incMemoryBytesSpilled(splitResult.getBytesToEvict)
    taskContext.taskMetrics().incDiskBytesSpilled(splitResult.getTotalBytesSpilled)

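    // Commit the map output: write the index/metadata file and promote dataTmp
    // to the final data file, then remove dataTmp if it still exists.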
    partitionLengths = splitResult.getPartitionLengths
    try {
      shuffleBlockResolver.writeMetadataFileAndCommit(
        dep.shuffleId,
        mapId,
        partitionLengths,
        Array[Long](),
        dataTmp)
    } finally {
      if (dataTmp.exists() && !dataTmp.delete()) {
        logError(s"Error while deleting temp file ${dataTmp.getAbsolutePath}")
      }
    }

    // The partition lengths reported here are much larger than vanilla Spark's
    // (roughly 3x). This value is sensitive in AQE rules such as
    // OptimizeSkewedJoin and DynamicJoinSelection, and may affect the final plan.
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
  }
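
handleEmptyInput() is referenced above but not shown in this excerpt. As a rough
sketch only (an assumption, not the file's actual implementation), an empty-input
fallback built from the same APIs visible above could look like this:

  private def handleEmptyInput(): Unit = {
    // Assumed sketch: report zero-length partitions and commit them so the
    // reduce side still sees a valid, empty map output for this task.
    partitionLengths = new Array[Long](dep.nativePartitioning.getNumPartitions)
    shuffleBlockResolver.writeMetadataFileAndCommit(
      dep.shuffleId,
      mapId,
      partitionLengths,
      Array[Long](), // no checksums, matching the call in internalWrite
      null) // no temp data file to promote
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
  }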