in spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala [99:170]
private def fetchNext(): Option[ColumnarBatch] = {
  if (channel == null || isClosed) {
    return None
  }

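  // Each shuffle block in the stream is framed as an 8-byte compressed length
  // (which counts the 8-byte field count header plus the compressed body),
  // followed by the 8-byte field count and then the compressed IPC body.
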
  // read compressed batch size from header
  try {
    longBuf.clear()
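    // fill the 8-byte length buffer; channel.read returns -1 at end of stream,
    // which terminates the loop early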
    while (longBuf.hasRemaining && channel.read(longBuf) >= 0) {}
  } catch {
    case _: EOFException =>
      close()
      return None
  }

  // If we reach the end of the stream, we are done. If we only read a partial
  // length, then the stream is corrupted.
  if (longBuf.hasRemaining) {
    if (longBuf.position() == 0) {
      close()
      return None
    }
    throw new EOFException("Data corrupt: unexpected EOF while reading compressed ipc lengths")
  }

  // get compressed length (including headers)
  longBuf.flip()
  val compressedLength = longBuf.getLong

  // read field count from header
  longBuf.clear()
  while (longBuf.hasRemaining && channel.read(longBuf) >= 0) {}
  if (longBuf.hasRemaining) {
    throw new EOFException("Data corrupt: unexpected EOF while reading field count")
  }
  longBuf.flip()
  val fieldCount = longBuf.getLong.toInt
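  // fieldCount is the number of columns in the encoded batch and is passed to
  // the native decoder below
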
  // read body
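  // the 8-byte field count header has already been consumed, so the remaining
  // body is compressedLength - 8 bytes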
  val bytesToRead = compressedLength - 8
  if (bytesToRead > Integer.MAX_VALUE) {
    // very unlikely that a shuffle block will reach 2GB
    throw new IllegalStateException(
      s"Native shuffle block size of $bytesToRead exceeds " +
        s"maximum of ${Integer.MAX_VALUE}. Try reducing shuffle batch size.")
  }

  var dataBuf = threadLocalDataBuf.get()
  if (dataBuf.capacity() < bytesToRead) {
    // it is unlikely that we would overflow here since it would
    // require a 1GB compressed shuffle block, but we check anyway
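    // grow to double the required size (clamped to Integer.MAX_VALUE) to reduce
    // reallocations for subsequent, slightly larger blocks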
    val newCapacity = (bytesToRead * 2L).min(Integer.MAX_VALUE).toInt
    dataBuf = ByteBuffer.allocateDirect(newCapacity)
    threadLocalDataBuf.set(dataBuf)
  }

  dataBuf.clear()
  dataBuf.limit(bytesToRead.toInt)
  while (dataBuf.hasRemaining && channel.read(dataBuf) >= 0) {}
  if (dataBuf.hasRemaining) {
    throw new EOFException("Data corrupt: unexpected EOF while reading compressed batch")
  }

  // make native call to decode batch
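  // getNextBatch invokes the closure with array/schema memory addresses;
  // decodeShuffleBlock decodes the compressed block in dataBuf and exports the
  // resulting columns through those addresses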
  val startTime = System.nanoTime()
  val batch = nativeUtil.getNextBatch(
    fieldCount,
    (arrayAddrs, schemaAddrs) => {
      native.decodeShuffleBlock(dataBuf, bytesToRead.toInt, arrayAddrs, schemaAddrs)
    })
  decodeTime.add(System.nanoTime() - startTime)
  batch
}