override def next()

in core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala [799:1093]


  override def next(): (BlockId, InputStream) = {
    if (!hasNext) {
      throw SparkCoreErrors.noSuchElementError()
    }

    numBlocksProcessed += 1

    var result: FetchResult = null
    var input: InputStream = null
    // This is only initialized when shuffle checksum is enabled.
    var checkedIn: CheckedInputStream = null
    var streamCompressedOrEncrypted: Boolean = false
    // Take the next fetched result and try to decompress it to detect data corruption,
    // then fetch it one more time if it's corrupt; throw a FetchFailedException if the second
    // fetch is also corrupt, so the previous stage can be retried.
    // For a local shuffle block, throw a FetchFailedException on the first IOException.
    while (result == null) {
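      // results.take() blocks until the next fetch result is available; the time spent
      // waiting here is recorded as shuffle fetch wait time in the task metrics.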
      val startFetchWait = System.nanoTime()
      result = results.take()
      val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait)
      shuffleMetrics.incFetchWaitTime(fetchWaitTime)

      result match {
        case SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) =>
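          // In-flight bookkeeping: executor-local blocks are skipped entirely; host-local
          // blocks and local push-merged chunks only update the local-read metrics, and only
          // truly remote blocks decrement the per-address and bytes-in-flight counters.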
          if (address != blockManager.blockManagerId) {
            if (hostLocalBlocks.contains(blockId -> mapIndex) ||
              pushBasedFetchHelper.isLocalPushMergedBlockAddress(address)) {
              // It is a host local block or a local shuffle chunk
              shuffleMetricsUpdate(blockId, buf, local = true)
            } else {
              numBlocksInFlightPerAddress(address) -= 1
              shuffleMetricsUpdate(blockId, buf, local = false)
              bytesInFlight -= size
            }
          }
          if (isNetworkReqDone) {
            reqsInFlight -= 1
            resetNettyOOMFlagIfPossible(maxReqSizeShuffleToMem)
            logDebug("Number of requests in flight " + reqsInFlight)
          }

          val in = if (buf.size == 0) {
            // We will never legitimately receive a zero-size block. All blocks with zero records
            // have zero size and all zero-size blocks have no records (and hence should never
            // have been requested in the first place). This statement relies on behaviors of the
            // shuffle writers, which are guaranteed by the following test cases:
            //
            // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions"
            // - UnsafeShuffleWriterSuite: "writeEmptyIterator"
            // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing"
            //
            // There is not an explicit test for SortShuffleWriter, but the underlying APIs that
            // it uses are shared with the UnsafeShuffleWriter (both writers use
            // DiskBlockObjectWriter, which returns a zero size from commitAndGet() in case no
            // records were written since the last call).
            val msg = log"Received a zero-size buffer for block ${MDC(BLOCK_ID, blockId)} " +
              log"from ${MDC(URI, address)} " +
              log"(expectedApproxSize = ${MDC(NUM_BYTES, size)}, " +
              log"isNetworkReqDone=${MDC(IS_NETWORK_REQUEST_DONE, isNetworkReqDone)})"
            if (blockId.isShuffleChunk) {
              // A zero-size block may come from nodes with hardware failures. For shuffle
              // chunks, the original shuffle blocks that belong to the zero-size shuffle chunk
              // are available, and we can opt to fall back immediately.
              logWarning(msg)
              pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
              shuffleMetrics.incCorruptMergedBlockChunks(1)
              // Set result to null to trigger another iteration of the while loop to get
              // either a SuccessFetchResult or a FailureFetchResult.
              result = null
              null
            } else {
              throwFetchFailedException(blockId, mapIndex, address, new IOException(msg.message))
            }
          } else {
            try {
              val bufIn = buf.createInputStream()
              if (checksumEnabled) {
                val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm)
                checkedIn = new CheckedInputStream(bufIn, checksum)
                checkedIn
              } else {
                bufIn
              }
            } catch {
              // The exception can only be thrown by a local shuffle block
              case e: IOException =>
                assert(buf.isInstanceOf[FileSegmentManagedBuffer])
                e match {
                  case ce: ClosedByInterruptException =>
                    lazy val error = MDC(ERROR, ce.getMessage)
                    logError(log"Failed to create input stream from local block, $error")
                  case e: IOException =>
                    logError("Failed to create input stream from local block", e)
                }
                buf.release()
                if (blockId.isShuffleChunk) {
                  pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
                  // Set result to null to trigger another iteration of the while loop to get
                  // either a SuccessFetchResult or a FailureFetchResult.
                  result = null
                  null
                } else {
                  throwFetchFailedException(blockId, mapIndex, address, e)
                }
            }
          }

          if (in != null) {
            try {
              input = streamWrapper(blockId, in)
              // If the stream is compressed or wrapped, then we optionally decompress/unwrap
              // the first maxBytesInFlight/3 bytes into memory, to check for corruption in that
              // portion of the data. Even if the 'detectCorruptUseExtraMemory' configuration is
              // off, or the corruption occurs beyond that portion, we'll still detect it later
              // while reading the stream.
              streamCompressedOrEncrypted = !input.eq(in)
              if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
                // TODO: manage the memory used here, and spill it into disk in case of OOM.
                input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
              }
            } catch {
              case e: IOException =>
                // When shuffle checksum is enabled, for a block that is corrupted twice,
                // we'd calculate the checksum of the block by consuming the remaining data
                // in the buf. So, we should release the buf later.
                if (!(checksumEnabled && corruptedBlocks.contains(blockId))) {
                  buf.release()
                }

                if (blockId.isShuffleChunk) {
                  shuffleMetrics.incCorruptMergedBlockChunks(1)
                  // TODO (SPARK-36284): Add shuffle checksum support for push-based shuffle
                  // Retrying a corrupt block may result again in a corrupt block. For shuffle
                  // chunks, we opt to fallback on the original shuffle blocks that belong to that
                  // corrupt shuffle chunk immediately instead of retrying to fetch the corrupt
                  // chunk. This also makes the code simpler because the chunkMeta corresponding to
                  // a shuffle chunk is always removed from chunksMetaMap whenever a shuffle chunk
                  // gets processed. If we try to re-fetch a corrupt shuffle chunk, then it has to
                  // be added back to the chunksMetaMap.
                  pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
                  // Set result to null to trigger another iteration of the while loop.
                  result = null
                } else if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
                  throwFetchFailedException(blockId, mapIndex, address, e)
                } else if (corruptedBlocks.contains(blockId)) {
                  // It's the second time this block is detected corrupted
                  if (checksumEnabled) {
                    // Diagnose the cause of data corruption if shuffle checksum is enabled
                    val diagnosisResponse = diagnoseCorruption(checkedIn, address, blockId)
                    buf.release()
                    logError(diagnosisResponse)
                    throwFetchFailedException(
                      blockId, mapIndex, address, e, Some(diagnosisResponse))
                  } else {
                    throwFetchFailedException(blockId, mapIndex, address, e)
                  }
                } else {
                  // It's the first time this block is detected corrupted
                  logWarning(log"got an corrupted block ${MDC(BLOCK_ID, blockId)} " +
                    log"from ${MDC(URI, address)}, fetch again", e)
                  corruptedBlocks += blockId
                  fetchRequests += FetchRequest(
                    address, Array(FetchBlockInfo(blockId, size, mapIndex)))
                  result = null
                }
            } finally {
              if (blockId.isShuffleChunk) {
                pushBasedFetchHelper.removeChunk(blockId.asInstanceOf[ShuffleBlockChunkId])
              }
              // TODO: release the buf here to free memory earlier
              if (input == null) {
                // Close the underlying stream if there was an issue in wrapping the stream using
                // streamWrapper
                in.close()
              }
            }
          }

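        // A non-retryable failure: surface it as a FetchFailedException so the scheduler can
        // mark the map output as missing and recompute it by retrying the previous stage.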
        case FailureFetchResult(blockId, mapIndex, address, e) =>
          var errorMsg: String = null
          if (e.isInstanceOf[OutOfDirectMemoryError]) {
            val logMessage = log"Block ${MDC(BLOCK_ID, blockId)} fetch failed after " +
              log"${MDC(MAX_ATTEMPTS, maxAttemptsOnNettyOOM)} retries due to Netty OOM"
            logError(logMessage)
            errorMsg = logMessage.message
          }
          // Use Option(...) so a null errorMsg becomes None and e.getMessage is used instead
          // of a null message.
          throwFetchFailedException(blockId, mapIndex, address, e, Option(errorMsg))

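        // The request was deferred rather than failed (e.g. after a Netty OOM on the remote
        // fetch): undo its in-flight accounting and park it in the per-address deferred queue
        // so fetchUpToMaxBytes() can re-issue it later.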
        case DeferFetchRequestResult(request) =>
          val address = request.address
          numBlocksInFlightPerAddress(address) -= request.blocks.size
          bytesInFlight -= request.size
          reqsInFlight -= 1
          logDebug("Number of requests in flight " + reqsInFlight)
          val defReqQueue =
            deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
          defReqQueue.enqueue(request)
          result = null

        case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) =>
          // We get this result in 3 cases:
          // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the
          //    blockId is a ShuffleBlockChunkId.
          // 2. Failure to read the push-merged-local meta. In this case, the blockId is a
          //    ShuffleBlockId.
          // 3. Failure to get the push-merged-local directories from the external shuffle
          //    service. In this case, the blockId is a ShuffleBlockId.
          if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) {
            numBlocksInFlightPerAddress(address) -= 1
            bytesInFlight -= size
          }
          if (isNetworkReqDone) {
            reqsInFlight -= 1
            logDebug("Number of requests in flight " + reqsInFlight)
          }
          pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
          // Set result to null to trigger another iteration of the while loop to get either
          // a SuccessFetchResult or a FailureFetchResult.
          result = null

        case PushMergedLocalMetaFetchResult(
          shuffleId, shuffleMergeId, reduceId, bitmaps, localDirs) =>
          // Fetch push-merged-local shuffle block data as multiple shuffle chunks
          val shuffleBlockId = ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId)
          try {
            val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId,
              localDirs)
            // Since the request for local block meta completed successfully, numBlocksToFetch
            // is decremented.
            numBlocksToFetch -= 1
            // Update the total number of blocks to fetch, reflecting the multiple local shuffle
            // chunks.
            numBlocksToFetch += bufs.size
            bufs.zipWithIndex.foreach { case (buf, chunkId) =>
              buf.retain()
              val shuffleChunkId = ShuffleBlockChunkId(shuffleId, shuffleMergeId, reduceId,
                chunkId)
              pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId))
              results.put(SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID,
                pushBasedFetchHelper.localShuffleMergerBlockMgrId, buf.size(), buf,
                isNetworkReqDone = false))
            }
          } catch {
            case e: Exception =>
              // If we see an exception while reading the push-merged-local index file, we fall
              // back to fetching the original blocks. We do not report a block fetch failure
              // and will continue with the remaining local block reads.
              logWarning("Error occurred while reading push-merged-local index, " +
                "prepare to fetch the original blocks", e)
              pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(
                shuffleBlockId, pushBasedFetchHelper.localShuffleMergerBlockMgrId)
          }
          result = null

        case PushMergedRemoteMetaFetchResult(
          shuffleId, shuffleMergeId, reduceId, blockSize, bitmaps, address) =>
          // The original meta request is processed, so we decrease numBlocksToFetch and
          // numBlocksInFlightPerAddress by 1. We will collect the new shuffle chunk requests,
          // and their count is added to numBlocksToFetch in collectFetchRequests.
          numBlocksInFlightPerAddress(address) -= 1
          numBlocksToFetch -= 1
          val blocksToFetch = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse(
            shuffleId, shuffleMergeId, reduceId, blockSize, bitmaps)
          val additionalRemoteReqs = new ArrayBuffer[FetchRequest]
          collectFetchRequests(address, blocksToFetch.toSeq, additionalRemoteReqs)
          fetchRequests ++= additionalRemoteReqs
          // Set result to null to force another iteration.
          result = null

        case PushMergedRemoteMetaFailedFetchResult(
          shuffleId, shuffleMergeId, reduceId, address) =>
          // The original meta request failed so we decrease numBlocksInFlightPerAddress by 1.
          numBlocksInFlightPerAddress(address) -= 1
          // If we fail to fetch the meta of a push-merged block, we fall back to fetching the
          // original blocks.
          pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(
            ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId), address)
          // Set result to null to force another iteration.
          result = null
      }

      // Send fetch requests up to maxBytesInFlight
      fetchUpToMaxBytes()
    }

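    // Wrap the stream so that closing it releases the underlying buffer, and so that, when
    // corruption detection applies, IOExceptions thrown while the caller reads the remainder
    // of the stream are reported as fetch failures rather than generic read errors.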
    currentResult = result.asInstanceOf[SuccessFetchResult]
    (currentResult.blockId,
      new BufferReleasingInputStream(
        input,
        this,
        currentResult.blockId,
        currentResult.mapIndex,
        currentResult.address,
        detectCorrupt && streamCompressedOrEncrypted,
        currentResult.isNetworkReqDone,
        Option(checkedIn)))
  }
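

A minimal, self-contained sketch of the checksum mechanism used above (illustrative only;
ChecksumSketch and the Adler32 choice are assumptions, not part of the file): reading through
a java.util.zip.CheckedInputStream updates its checksum as a side effect, so the reader can
drain a block and then compare the resulting value against a writer-side checksum, which is
the idea behind the checkedIn stream handed to diagnoseCorruption.

import java.io.ByteArrayInputStream
import java.util.zip.{Adler32, CheckedInputStream}

object ChecksumSketch {
  def main(args: Array[String]): Unit = {
    val data = "example shuffle payload".getBytes("UTF-8")
    // Wrap the raw stream, mirroring the checksumEnabled branch in next().
    val checked = new CheckedInputStream(new ByteArrayInputStream(data), new Adler32)
    val buf = new Array[Byte](4096)
    while (checked.read(buf) != -1) () // drain; the checksum updates as bytes are read
    println(s"Adler32 over the block: ${checked.getChecksum.getValue}")
  }
}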