in src/main/scala/org/apache/spark/shuffle/rss/BlockDownloaderPartitionRecordIterator.scala [131:198]
  private def fetchNextDeserializationIterator(): Unit = {
    clearDeserializationStream()
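
    // Fetch the next data block, skipping blocks whose payload is null or empty,
    // and fold the time spent into the fetch-time metric.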
    val readRecordStartNanoTime = System.nanoTime()
    var dataBlock: TaskDataBlock = null
    try {
      dataBlock = downloader.readDataBlock()
      fetchNanoTime += System.nanoTime() - readRecordStartNanoTime
      while (dataBlock != null &&
        (dataBlock.getPayload == null || dataBlock.getPayload.size == 0)) {
        val retryFetchStartNanoTime = System.nanoTime()
        dataBlock = downloader.readDataBlock()
        fetchNanoTime += System.nanoTime() - retryFetchStartNanoTime
      }
    } catch {
      case ex: Throwable =>
        downloader.close()
        M3Stats.addException(ex, this.getClass().getSimpleName())
        throw new FetchFailedException(
          RssUtils.createReduceTaskDummyBlockManagerId(shuffleId, partition),
          shuffleId,
          -1,
          partition,
          s"Failed to read data from shuffle $shuffleId partition $partition due to ${ExceptionUtils.getSimpleMessage(ex)}",
          ex)
    }
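
    // Track cumulative remote bytes read for shuffle metrics reporting.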
    numRemoteBytesRead = downloader.getShuffleReadBytes
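
    // A null block means the downloader is exhausted: mark EOF, release resources,
    // and leave the iterator unset so callers see the end of the partition.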
    if (dataBlock == null) {
      downloaderEof = true
      downloader.close()
      deserializationIterator = null
      return
    }
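
    // Each payload is framed as [compressedLen: Int][uncompressedLen: Int][compressed bytes].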
    val decompressStartTime = System.nanoTime()
    val bytes = dataBlock.getPayload
    val compressedLen = ByteBufUtils.readInt(bytes, 0)
    val uncompressedLen = ByteBufUtils.readInt(bytes, Integer.BYTES)
    val uncompressedBytes = new Array[Byte](uncompressedLen)
    if (Compression.COMPRESSION_CODEC_ZSTD.equals(decompression)) {
      // TODO: Zstd in Spark 2.4 does not support decompressing from a range of the source
      // byte array. For Spark versions above 2.4, prefer Zstd.decompressByteArray to avoid
      // copying bytes (see the sketch after this method).
      val sourceBytes = util.Arrays.copyOfRange(bytes, Integer.BYTES + Integer.BYTES, bytes.length)
      val n = Zstd.decompress(uncompressedBytes, sourceBytes)
      if (Zstd.isError(n)) {
        throw new RssInvalidDataException(
          s"Data corrupted for shuffle $shuffleId partition $partition, failed to decompress zstd, " +
            s"decompress returned: $n, " + String.valueOf(downloader))
      }
    } else {
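      // LZ4FastDecompressor.decompress returns the number of compressed bytes it consumed,
      // which must match the compressed length recorded in the frame header.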
      val count = lz4Decompressor.decompress(bytes, Integer.BYTES + Integer.BYTES, uncompressedBytes, 0, uncompressedLen)
      if (count != compressedLen) {
        throw new RssInvalidDataException(
          s"Data corrupted for shuffle $shuffleId partition $partition, expected compressed length: $compressedLen, " +
            s"but it is: $count, " + String.valueOf(downloader))
      }
    }
    decompressTime += (System.nanoTime() - decompressStartTime)
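
    // Wrap the decompressed bytes in a deserialization stream and expose its
    // records as a key-value iterator for the caller.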
    deserializationInput = new Input(uncompressedBytes, 0, uncompressedLen)
    deserializationStream = serializerInstance.deserializeStream(deserializationInput)
    deserializationIterator = deserializationStream.asKeyValueIterator
    logShuffleFetchInfo(false)
  }
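
  // A minimal sketch of the TODO above, assuming a zstd-jni version that provides
  // Zstd.decompressByteArray; it decompresses directly from an offset into the
  // payload and would replace the Arrays.copyOfRange plus Zstd.decompress pair:
  //
  //   val n = Zstd.decompressByteArray(
  //     uncompressedBytes, 0, uncompressedLen,          // destination buffer and its size
  //     bytes, Integer.BYTES + Integer.BYTES,           // source offset past the two length ints
  //     bytes.length - Integer.BYTES - Integer.BYTES)   // remaining compressed bytes
  //   if (Zstd.isError(n)) {
  //     // same corruption handling as in the zstd branch above
  //   }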