in backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala [132:249]
def internalWrite(records: Iterator[Product2[K, V]]): Unit = {
  if (!records.hasNext) {
    handleEmptyInput()
    return
  }
  val dataTmp = Utils.tempFileWith(shuffleBlockResolver.getDataFile(dep.shuffleId, mapId))
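  // Stream each non-empty input ColumnarBatch into the native shuffle writer.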
  while (records.hasNext) {
    val cb = records.next()._2.asInstanceOf[ColumnarBatch]
    if (cb.numRows == 0 || cb.numCols == 0) {
      logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} cols")
    } else {
      val rows = cb.numRows()
      val handle = ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, cb)
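      // Lazily create the native shuffle writer on the first non-empty batch.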
      if (nativeShuffleWriter == -1L) {
        nativeShuffleWriter = jniWrapper.make(
          dep.nativePartitioning.getShortName,
          dep.nativePartitioning.getNumPartitions,
          nativeBufferSize,
          nativeMergeBufferSize,
          nativeMergeThreshold,
          compressionCodec.orNull,
          compressionCodecBackend.orNull,
          compressionLevel,
          compressionBufferSize,
          conf.get(SHUFFLE_DISK_WRITE_BUFFER_SIZE).toInt,
          bufferCompressThreshold,
          GlutenConfig.get.columnarShuffleCompressionMode,
          conf.get(SHUFFLE_SORT_INIT_BUFFER_SIZE).toInt,
          conf.get(SHUFFLE_SORT_USE_RADIXSORT),
          dataTmp.getAbsolutePath,
          blockManager.subDirsPerLocalDir,
          localDirs,
          reallocThreshold,
          handle,
          taskContext.taskAttemptId(),
          GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, taskContext.partitionId),
          shuffleWriterType
        )
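        // Register a spiller so the memory manager can ask the native writer
        // to evict buffered data when off-heap memory runs low.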
        runtime
          .memoryManager()
          .addSpiller(new Spiller() {
            override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = {
              if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
                return 0L
              }
              logInfo(s"Gluten shuffle writer: Trying to spill $size bytes of data")
              // fixme pass true when being called by self
              val spilled =
                jniWrapper.nativeEvict(nativeShuffleWriter, size, false)
              logInfo(s"Gluten shuffle writer: Spilled $spilled / $size bytes of data")
              spilled
            }
          })
      }
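      // Write the batch through JNI and update per-batch metrics.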
      val startTime = System.nanoTime()
      jniWrapper.write(nativeShuffleWriter, rows, handle, availableOffHeapPerTask())
      dep.metrics("shuffleWallTime").add(System.nanoTime() - startTime)
      dep.metrics("numInputRows").add(rows)
      dep.metrics("inputBatches").add(1)
      // This metric is important: AQE uses it to decide whether EliminateLimit applies.
      writeMetrics.incRecordsWritten(rows)
    }
    cb.close()
  }
  if (nativeShuffleWriter == -1L) {
    handleEmptyInput()
    return
  }
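  // Finalize the native writer: stop() flushes any remaining buffered data and
  // returns the split result consumed by the metrics updates below.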
  val startTime = System.nanoTime()
  assert(nativeShuffleWriter != -1L)
  splitResult = jniWrapper.stop(nativeShuffleWriter)
  closeShuffleWriter()
  dep.metrics("shuffleWallTime").add(System.nanoTime() - startTime)
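  // Hash-based shuffle derives split time from wall time minus spill, write, and
  // compress time; sort-based shuffle reports sort and C2R (columnar-to-row) times instead.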
  if (!isSort) {
    dep
      .metrics("splitTime")
      .add(
        dep.metrics("shuffleWallTime").value - splitResult.getTotalSpillTime -
          splitResult.getTotalWriteTime -
          splitResult.getTotalCompressTime)
  } else {
    dep.metrics("sortTime").add(splitResult.getSortTime)
    dep.metrics("c2rTime").add(splitResult.getC2RTime)
  }
  dep.metrics("spillTime").add(splitResult.getTotalSpillTime)
  dep.metrics("bytesSpilled").add(splitResult.getTotalBytesSpilled)
  dep.metrics("dataSize").add(splitResult.getRawPartitionLengths.sum)
  dep.metrics("compressTime").add(splitResult.getTotalCompressTime)
  dep.metrics("peakBytes").add(splitResult.getPeakBytes)
  writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten)
  writeMetrics.incWriteTime(splitResult.getTotalWriteTime + splitResult.getTotalSpillTime)
  taskContext.taskMetrics().incMemoryBytesSpilled(splitResult.getBytesToEvict)
  taskContext.taskMetrics().incDiskBytesSpilled(splitResult.getTotalBytesSpilled)
  partitionLengths = splitResult.getPartitionLengths
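  // Commit the shuffle output: write the index/metadata file and commit the temp
  // data file, then clean up the temp file if it still exists.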
  try {
    shuffleBlockResolver.writeMetadataFileAndCommit(
      dep.shuffleId,
      mapId,
      partitionLengths,
      Array[Long](),
      dataTmp)
  } finally {
    if (dataTmp.exists() && !dataTmp.delete()) {
      logError(s"Error while deleting temp file ${dataTmp.getAbsolutePath}")
    }
  }
  // The partitionLengths reported here are much larger than vanilla Spark's,
  // almost 3x. The value is sensitive in AQE rules such as OptimizeSkewedJoin
  // and DynamicJoinSelection, and may affect the final plan.
  mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
}