private def fetchNextDeserializationIterator()

in src/main/scala/org/apache/spark/shuffle/rss/BlockDownloaderPartitionRecordIterator.scala [131:198]


  /**
   * Fetches the next non-empty data block from the downloader, decompresses its payload,
   * and rebuilds the Kryo deserialization stream/iterator over its key-value records.
   *
   * Side effects: accumulates fetch time into `fetchNanoTime` and decompress time into
   * `decompressTime`, refreshes `numRemoteBytesRead`, and replaces the
   * `deserializationInput`/`deserializationStream`/`deserializationIterator` triple.
   * On end of stream it sets `downloaderEof`, closes the downloader, and nulls the iterator.
   *
   * @throws FetchFailedException   if reading from the downloader fails; wrapped so that
   *                                Spark treats it as a fetch failure and can retry the stage
   * @throws RssInvalidDataException if the payload fails decompression validation
   */
  private def fetchNextDeserializationIterator(): Unit = {
    clearDeserializationStream()

    var dataBlock: TaskDataBlock = null

    try {
      // Read blocks until one with a non-empty payload is found (or the stream ends).
      // Only downloader read time is charged to fetchNanoTime.
      var fetchStartNanos = System.nanoTime()
      dataBlock = downloader.readDataBlock()
      fetchNanoTime += System.nanoTime() - fetchStartNanos

      while (dataBlock != null &&
        (dataBlock.getPayload == null || dataBlock.getPayload.size == 0)) {
        fetchStartNanos = System.nanoTime()
        dataBlock = downloader.readDataBlock()
        fetchNanoTime += System.nanoTime() - fetchStartNanos
      }
    } catch {
      // Intentionally broad catch: any failure while fetching must surface as a
      // FetchFailedException so Spark's scheduler can retry, rather than failing
      // the task with an opaque error.
      case ex: Throwable =>
        downloader.close()
        M3Stats.addException(ex, this.getClass().getSimpleName())
        throw new FetchFailedException(
          RssUtils.createReduceTaskDummyBlockManagerId(shuffleId, partition),
          shuffleId,
          -1, // map id unknown for RSS-based shuffle; dummy value
          partition,
          s"Failed to read data from shuffle $shuffleId partition $partition due to ${ExceptionUtils.getSimpleMessage(ex)}",
          ex)
    }

    numRemoteBytesRead = downloader.getShuffleReadBytes

    if (dataBlock == null) {
      // End of stream: release the downloader and signal exhaustion with a null iterator.
      downloaderEof = true
      downloader.close()
      deserializationIterator = null
      return
    }

    // Payload layout: [compressedLen: Int][uncompressedLen: Int][compressed bytes ...]
    val decompressStartTime = System.nanoTime()
    val bytes = dataBlock.getPayload
    val compressedLen = ByteBufUtils.readInt(bytes, 0)
    val uncompressedLen = ByteBufUtils.readInt(bytes, Integer.BYTES)
    val uncompressedBytes = new Array[Byte](uncompressedLen)
    if (Compression.COMPRESSION_CODEC_ZSTD.equals(decompression)) {
      // TODO Zstd in Spark 2.4 does not support decompress method with a range from source byte array
      // Better to use Zstd.decompressByteArray for Spark version higher than 2.4 to avoid copying bytes
      val sourceBytes = util.Arrays.copyOfRange(bytes, Integer.BYTES + Integer.BYTES, bytes.length)
      val n = Zstd.decompress(uncompressedBytes, sourceBytes)
      if (Zstd.isError(n)) {
        throw new RssInvalidDataException(
          s"Data corrupted for shuffle $shuffleId partition $partition, failed to decompress zstd, decompress returned: $n, " + String.valueOf(downloader))
      }
    } else {
      // LZ4 decompress returns the number of compressed bytes it consumed; a mismatch
      // with the recorded compressed length indicates corruption.
      val count = lz4Decompressor.decompress(bytes, Integer.BYTES + Integer.BYTES, uncompressedBytes, 0, uncompressedLen)
      if (count != compressedLen) {
        throw new RssInvalidDataException(
          s"Data corrupted for shuffle $shuffleId partition $partition, expected compressed length: $compressedLen, but it is: $count, " + String.valueOf(downloader))
      }
    }
    decompressTime += (System.nanoTime() - decompressStartTime)

    deserializationInput = new Input(uncompressedBytes, 0, uncompressedLen)
    deserializationStream = serializerInstance.deserializeStream(deserializationInput)
    deserializationIterator = deserializationStream.asKeyValueIterator

    logShuffleFetchInfo(false)
  }