in gluten-data/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala [115:232]
/**
 * Writes the columnar records of this map task through the native shuffle writer and
 * commits the resulting data/index files, populating `mapStatus` for the scheduler.
 *
 * The native writer is created lazily on the first non-empty batch so that memory is
 * only reserved when there is real work to do. If no non-empty batch is ever seen
 * (either the iterator is empty, or every batch has 0 rows / 0 cols), an all-empty
 * partition map is committed instead so downstream stages still get a valid MapStatus.
 */
def internalWrite(records: Iterator[Product2[K, V]]): Unit = {
  // Commit an all-empty map output: zero-length partitions, no data file.
  // Shared by the empty-iterator fast path and the all-batches-empty path.
  def commitEmptyOutput(): Unit = {
    partitionLengths = new Array[Long](dep.partitioner.numPartitions)
    shuffleBlockResolver.writeMetadataFileAndCommit(
      dep.shuffleId,
      mapId,
      partitionLengths,
      Array[Long](),
      null)
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
  }

  if (!records.hasNext) {
    commitEmptyOutput()
    return
  }
  val dataTmp = Utils.tempFileWith(shuffleBlockResolver.getDataFile(dep.shuffleId, mapId))
  while (records.hasNext) {
    val cb = records.next()._2.asInstanceOf[ColumnarBatch]
    if (cb.numRows == 0 || cb.numCols == 0) {
      logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} cols")
    } else {
      val rows = cb.numRows()
      val handle = ColumnarBatches.getNativeHandle(cb)
      if (nativeShuffleWriter == -1L) {
        // Lazily create the native writer on the first non-empty batch.
        nativeShuffleWriter = jniWrapper.make(
          dep.nativePartitioning,
          nativeBufferSize,
          nativeMergeBufferSize,
          nativeMergeThreshold,
          compressionCodec,
          compressionCodecBackend,
          bufferCompressThreshold,
          GlutenConfig.getConf.columnarShuffleCompressionMode,
          dataTmp.getAbsolutePath,
          blockManager.subDirsPerLocalDir,
          localDirs,
          NativeMemoryManagers
            .create(
              "ShuffleWriter",
              new Spiller() {
                override def spill(self: MemoryTarget, size: Long): Long = {
                  // The memory manager may only ask us to spill once the native
                  // writer exists; make() itself must not trigger a spill here.
                  if (nativeShuffleWriter == -1L) {
                    throw new IllegalStateException(
                      "Fatal: spill() called before a shuffle writer " +
                        "is created. This behavior should be optimized by moving memory " +
                        "allocations from make() to split()")
                  }
                  logInfo(s"Gluten shuffle writer: Trying to spill $size bytes of data")
                  // fixme pass true when being called by self
                  val spilled =
                    jniWrapper.nativeEvict(nativeShuffleWriter, size, false)
                  logInfo(s"Gluten shuffle writer: Spilled $spilled / $size bytes of data")
                  spilled
                }
                override def applicablePhases(): java.util.Set[Spiller.Phase] =
                  Spillers.PHASE_SET_SPILL_ONLY
              }
            )
            .getNativeInstanceHandle,
          reallocThreshold,
          handle,
          taskContext.taskAttemptId(),
          GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, taskContext.partitionId)
        )
      }
      val startTime = System.nanoTime()
      val bytes = jniWrapper.split(nativeShuffleWriter, rows, handle, availableOffHeapPerTask())
      dep.metrics("splitTime").add(System.nanoTime() - startTime)
      dep.metrics("numInputRows").add(rows)
      dep.metrics("inputBatches").add(1)
      // This metric is important, AQE use it to decide if EliminateLimit
      writeMetrics.incRecordsWritten(rows)
    }
    cb.close()
  }

  // Bug fix: if every batch was empty, the native writer was never created and
  // `splitResult` was never populated; dereferencing it below would NPE.
  // Commit an empty output instead, mirroring the empty-iterator fast path.
  if (nativeShuffleWriter == -1L) {
    commitEmptyOutput()
    return
  }

  val startTime = System.nanoTime()
  splitResult = jniWrapper.stop(nativeShuffleWriter)
  closeShuffleWriter
  // stop() time minus the time the native side attributes to spill/write/compress
  // is counted as split time; spill/compress get their own metrics below.
  dep
    .metrics("splitTime")
    .add(
      System.nanoTime() - startTime - splitResult.getTotalSpillTime -
        splitResult.getTotalWriteTime -
        splitResult.getTotalCompressTime)
  dep.metrics("spillTime").add(splitResult.getTotalSpillTime)
  dep.metrics("compressTime").add(splitResult.getTotalCompressTime)
  dep.metrics("bytesSpilled").add(splitResult.getTotalBytesSpilled)
  dep.metrics("splitBufferSize").add(splitResult.getSplitBufferSize)
  dep.metrics("uncompressedDataSize").add(splitResult.getRawPartitionLengths.sum)
  dep.metrics("dataSize").add(splitResult.getRawPartitionLengths.sum)
  writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten)
  writeMetrics.incWriteTime(splitResult.getTotalWriteTime + splitResult.getTotalSpillTime)
  partitionLengths = splitResult.getPartitionLengths
  rawPartitionLengths = splitResult.getRawPartitionLengths
  try {
    shuffleBlockResolver.writeMetadataFileAndCommit(
      dep.shuffleId,
      mapId,
      partitionLengths,
      Array[Long](),
      dataTmp)
  } finally {
    // The resolver may have renamed dataTmp into place; only a leftover tmp
    // file needs cleanup here.
    if (dataTmp.exists() && !dataTmp.delete()) {
      logError(s"Error while deleting temp file ${dataTmp.getAbsolutePath}")
    }
  }
  // The partitionLength is much more than vanilla spark partitionLengths,
  // almost 3 times than vanilla spark partitionLengths
  // This value is sensitive in rules such as AQE rule OptimizeSkewedJoin DynamicJoinSelection
  // May affect the final plan
  mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
}