override def read()

in client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala [78:426]


  /**
   * Reads the reduce partitions in `[startPartition, endPartition)` for this shuffle and returns
   * them as a single iterator of key/combiner pairs.
   *
   * The method works in the following phases:
   *  1. Resolve the Celeborn shuffle id for `handle` and register it with `shuffleIdTracker`.
   *  2. Obtain the file groups of the shuffle through `shuffleClient.updateFileGroup`.
   *  3. Group the partition locations by worker (`hostAndFetchPort`) and send one
   *     `BATCH_OPEN_STREAM` request per worker on `streamCreatorPool`, waiting for all requests.
   *     A failed batch request is logged and otherwise ignored (best effort).
   *  4. Create the `CelebornInputStream`s asynchronously on `streamCreatorPool`, keeping at most
   *     `conf.clientInputStreamCreationWindow` partitions ahead of the consumer.
   *  5. Lazily deserialize the streams, wrap the records with metrics and interruption support,
   *     and finally sort and/or aggregate them as requested by `dep`.
   *
   * @return an iterator over the records of the requested partitions; it is always wrapped in
   *         (or already is) an `InterruptibleIterator` so that task cancellation is honoured.
   */
  override def read(): Iterator[Product2[K, C]] = {

    val startTime = System.currentTimeMillis()
    val serializerInstance = newSerializerInstance(dep)
    val shuffleId = SparkUtils.celebornShuffleId(shuffleClient, handle, context, false)
    shuffleIdTracker.track(handle.shuffleId, shuffleId)
    logDebug(
      s"get shuffleId $shuffleId for appShuffleId ${handle.shuffleId} attemptNum ${context.stageAttemptNumber()}")

    // Update the context task metrics for each record read.
    val metricsCallback = new MetricsCallback {
      // Every call adds the given byte count and counts as one fetched remote block.
      override def incBytesRead(bytesWritten: Long): Unit = {
        metrics.incRemoteBytesRead(bytesWritten)
        metrics.incRemoteBlocksFetched(1)
      }

      override def incReadTime(time: Long): Unit =
        metrics.incFetchWaitTime(time)
    }

    // Lazily create the shared stream creator pool. The null check is repeated inside the
    // synchronized block so that only the first caller creates the pool.
    if (streamCreatorPool == null) {
      CelebornShuffleReader.synchronized {
        if (streamCreatorPool == null) {
          streamCreatorPool = ThreadUtils.newDaemonCachedThreadPool(
            "celeborn-create-stream-thread",
            conf.readStreamCreatorPoolThreads,
            60)
        }
      }
    }

    val fetchTimeoutMs = conf.clientFetchTimeoutMs
    val localFetchEnabled = conf.enableReadLocalShuffleFile
    val localHostAddress = Utils.localHostName(conf)
    val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId)
    var fileGroups: ReduceFileGroups = null
    try {
      // startPartition is irrelevant
      fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
    } catch {
      case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
        // if a task is interrupted, should not report fetch failure
        // if a task update file group timeout, should not report fetch failure
        // NOTE(review): the code below dereferences fileGroups unconditionally, so this helper is
        // presumably expected to always throw - confirm against its definition.
        checkAndReportFetchFailureForUpdateFileGroupFailure(shuffleId, ce)
      case e: Throwable => throw e
    }

    val batchOpenStreamStartTime = System.currentTimeMillis()
    // host-port -> (TransportClient, PartitionLocation Array, PbOpenStreamList)
    val workerRequestMap = new JHashMap[
      String,
      (TransportClient, JArrayList[PartitionLocation], PbOpenStreamList.Builder)]()
    // partitionId -> (partition uniqueId -> chunkRange pair)
    val partitionId2ChunkRange = new JHashMap[Int, JMap[String, Pair[Integer, Integer]]]()

    // partitionId -> locations that remain after skew-split filtering (only filled in that mode)
    val partitionId2PartitionLocations = new JHashMap[Int, JSet[PartitionLocation]]()

    // Number of partition locations visited; only used in the log message below.
    var partCnt = 0

    // if startMapIndex > endMapIndex, means partition is skew partition and read by Celeborn implementation.
    // locations will split to sub-partitions with startMapIndex size.
    val splitSkewPartitionWithoutMapRange =
      ClientUtils.readSkewPartitionWithoutMapRange(conf, startMapIndex, endMapIndex)

    (startPartition until endPartition).foreach { partitionId =>
      if (fileGroups.partitionGroups.containsKey(partitionId)) {
        var locations = fileGroups.partitionGroups.get(partitionId)
        if (splitSkewPartitionWithoutMapRange) {
          val partitionLocation2ChunkRange = CelebornPartitionUtil.splitSkewedPartitionLocations(
            new JArrayList(locations),
            startMapIndex,
            endMapIndex)
          partitionId2ChunkRange.put(partitionId, partitionLocation2ChunkRange)
          // filter locations avoid OPEN_STREAM when split skew partition without map range
          val filterLocations = locations.asScala
            .filter { location =>
              null != partitionLocation2ChunkRange &&
              partitionLocation2ChunkRange.containsKey(location.getUniqueId)
            }
          locations = filterLocations.asJava
          partitionId2PartitionLocations.put(partitionId, locations)
        }

        locations.asScala.foreach { location =>
          partCnt += 1
          val hostPort = location.hostAndFetchPort
          // Create the client and the request builder the first time a worker is seen.
          if (!workerRequestMap.containsKey(hostPort)) {
            try {
              val client = shuffleClient.getDataClientFactory().createClient(
                location.getHost,
                location.getFetchPort)
              val pbOpenStreamList = PbOpenStreamList.newBuilder()
              pbOpenStreamList.setShuffleKey(shuffleKey)
              workerRequestMap.put(
                hostPort,
                (client, new JArrayList[PartitionLocation], pbOpenStreamList))
            } catch {
              case ex: Exception =>
                shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort, ex)
                logWarning(
                  s"Failed to create client for $shuffleKey-$partitionId from host: ${location.hostAndFetchPort}. " +
                    s"Shuffle reader will try its replica if exists.")
            }
          }
          // When client creation failed above there is no entry for this worker; the lookup
          // then yields null, which does not match the tuple pattern and falls through to the
          // default case.
          workerRequestMap.get(hostPort) match {
            case (_, locArr, pbOpenStreamListBuilder) =>
              // locArr and the builder are appended in lock step, so the i-th location
              // corresponds to the i-th file name of the request.
              locArr.add(location)
              pbOpenStreamListBuilder.addFileName(location.getFileName)
                .addStartIndex(startMapIndex)
                .addEndIndex(endMapIndex)
              pbOpenStreamListBuilder.addReadLocalShuffle(
                localFetchEnabled && location.getHost.equals(localHostAddress))
            case _ =>
              logDebug(s"Empty client for host ${hostPort}")
          }
        }
      }
    }

    // Stream handlers obtained from the batch requests, keyed by location. Written concurrently
    // by the pool threads below.
    val locationStreamHandlerMap: ConcurrentHashMap[PartitionLocation, PbStreamHandler] =
      JavaUtils.newConcurrentHashMap()

    // Iterate over the map entries (not only the values) so that the worker address is available
    // for diagnostics when a batch request fails.
    val futures = workerRequestMap.asScala.toList.map { case (hostPort, entry) =>
      streamCreatorPool.submit(new Runnable {
        override def run(): Unit = {
          val (client, locArr, pbOpenStreamListBuilder) = entry
          val msg = new TransportMessage(
            MessageType.BATCH_OPEN_STREAM,
            pbOpenStreamListBuilder.build().toByteArray)
          val pbOpenStreamListResponse =
            try {
              val response = client.sendRpcSync(msg.toByteBuffer, fetchTimeoutMs)
              TransportMessage.fromByteBuffer(response).getParsedPayload[PbOpenStreamListResponse]
            } catch {
              case e: Exception =>
                // Best effort: a failed batch request must not fail the task here, but it should
                // not go unnoticed either. No handler is recorded for the affected locations.
                logWarning(
                  s"BatchOpenStream for $shuffleKey to worker $hostPort failed for " +
                    s"${locArr.size()} location(s), no stream handler is recorded for them.",
                  e)
                null
            }
          if (pbOpenStreamListResponse != null) {
            // The response is read positionally: entry idx belongs to locArr(idx).
            0 until locArr.size() foreach { idx =>
              val streamHandlerOpt = pbOpenStreamListResponse.getStreamHandlerOptList.get(idx)
              if (streamHandlerOpt.getStatus == StatusCode.SUCCESS.getValue) {
                locationStreamHandlerMap.put(locArr.get(idx), streamHandlerOpt.getStreamHandler)
              }
            }
          }
        }
      })
    }
    // wait for all futures to complete
    futures.foreach(f => f.get())
    val end = System.currentTimeMillis()
    // readTime should include batchOpenStreamTime, getShuffleId Rpc time and updateFileGroup Rpc time
    metricsCallback.incReadTime(end - startTime)
    logInfo(s"BatchOpenStream for $partCnt cost ${end - batchOpenStreamStartTime}ms")

    // partitionId -> input stream, filled by the pool threads and polled by the record iterator.
    val streams = JavaUtils.newConcurrentHashMap[Integer, CelebornInputStream]()

    // Creates the input stream of one partition and publishes it in `streams`. Any failure is
    // recorded in `exceptionRef` (first one wins) instead of being thrown, because this runs on
    // a pool thread; once a failure is recorded no further streams are created.
    def createInputStream(partitionId: Int): Unit = {
      val locations =
        if (splitSkewPartitionWithoutMapRange) {
          partitionId2PartitionLocations.get(partitionId)
        } else {
          fileGroups.partitionGroups.get(partitionId)
        }

      val locationList =
        if (null == locations) {
          new JArrayList[PartitionLocation]()
        } else {
          new JArrayList[PartitionLocation](locations)
        }
      // Handlers are aligned with locationList by index; the entry is null for a location
      // without a handler from the batch request.
      val streamHandlers =
        if (locations != null) {
          val streamHandlerArr = new JArrayList[PbStreamHandler](locationList.size)
          locationList.asScala.foreach { loc =>
            streamHandlerArr.add(locationStreamHandlerMap.get(loc))
          }
          streamHandlerArr
        } else null
      if (exceptionRef.get() == null) {
        try {
          val inputStream = shuffleClient.readPartition(
            shuffleId,
            handle.shuffleId,
            partitionId,
            encodedAttemptId,
            context.taskAttemptId(),
            startMapIndex,
            endMapIndex,
            if (throwsFetchFailure) ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER
            else null,
            locationList,
            streamHandlers,
            fileGroups.pushFailedBatches,
            partitionId2ChunkRange.get(partitionId),
            fileGroups.mapAttempts,
            metricsCallback)
          streams.put(partitionId, inputStream)
        } catch {
          case e: IOException =>
            logError(s"Exception caught when readPartition $partitionId!", e)
            exceptionRef.compareAndSet(null, e)
          case e: Throwable =>
            logError(s"Non IOException caught when readPartition $partitionId!", e)
            exceptionRef.compareAndSet(null, new CelebornIOException(e))
        }
      }
    }

    // Kick off stream creation for the first `inputStreamCreationWindow` partitions.
    val inputStreamCreationWindow = conf.clientInputStreamCreationWindow
    (startPartition until Math.min(
      startPartition + inputStreamCreationWindow,
      endPartition)).foreach(partitionId => {
      streamCreatorPool.submit(new Runnable {
        override def run(): Unit = {
          createInputStream(partitionId)
        }
      })
    })

    // Lazily walks the partitions in order. For each one it waits until the pool has published
    // the input stream (or a failure has been recorded) and then schedules the creation of the
    // stream one window further ahead.
    val recordIter = (startPartition until endPartition).iterator.map(partitionId => {
      if (handle.numMappers > 0) {
        val startFetchWait = System.nanoTime()
        var inputStream: CelebornInputStream = streams.get(partitionId)
        var sleepCnt = 0L
        // Poll every 5ms until the stream shows up or a creation failure is observed.
        while (inputStream == null) {
          if (exceptionRef.get() != null) {
            exceptionRef.get() match {
              case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
                handleFetchExceptions(handle.shuffleId, shuffleId, partitionId, ce)
              case e => throw e
            }
          }
          if (sleepCnt == 0) {
            logInfo(s"inputStream for partition: $partitionId is null, sleeping 5ms")
          }
          sleepCnt += 1
          Thread.sleep(5)
          inputStream = streams.get(partitionId)
        }
        if (sleepCnt > 0) {
          logInfo(
            s"inputStream for partition: $partitionId is not null, sleep $sleepCnt times for ${5 * sleepCnt} ms")
        }
        metricsCallback.incReadTime(
          TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait))
        // ensure inputStream is closed when task completes
        context.addTaskCompletionListener[Unit](_ => inputStream.close())

        // Advance the input creation window
        if (partitionId + inputStreamCreationWindow < endPartition) {
          streamCreatorPool.submit(new Runnable {
            override def run(): Unit = {
              createInputStream(partitionId + inputStreamCreationWindow)
            }
          })
        }

        (partitionId, inputStream)
      } else {
        (partitionId, CelebornInputStream.empty())
      }
    }).filter {
      // Drop partitions that resolved to the empty stream.
      case (_, inputStream) => inputStream != CelebornInputStream.empty()
    }.map { case (partitionId, inputStream) =>
      (partitionId, serializerInstance.deserializeStream(inputStream).asKeyValueIterator)
    }.flatMap { case (partitionId, iter) =>
      // NOTE(review): this try only guards the evaluation of the `iter` reference. Exceptions
      // raised later, while flatMap consumes the iterator, are thrown outside of it and are not
      // routed through handleFetchExceptions here - confirm whether that is intended.
      try {
        iter
      } catch {
        case e @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
          handleFetchExceptions(handle.shuffleId, shuffleId, partitionId, e)
      }
    }

    // Count every record that is handed downstream.
    val iterWithUpdatedRecordsRead =
      recordIter.map { record =>
        metrics.incRecordsRead(1)
        record
      }

    // Merge the shuffle read metrics into the task metrics once the records are exhausted.
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      iterWithUpdatedRecordsRead,
      context.taskMetrics().mergeShuffleReadMetrics())

    // An interruptible iterator must be used here in order to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val resultIter: Iterator[Product2[K, C]] = {
      // Sort the output if there is a sort ordering defined.
      if (dep.keyOrdering.isDefined) {
        // Create an ExternalSorter to sort the data.
        val sorter: ExternalSorter[K, _, C] =
          if (dep.aggregator.isDefined) {
            if (dep.mapSideCombine) {
              // Values are already combiners, so the sorter only has to merge combiners.
              new ExternalSorter[K, C, C](
                context,
                Option(new Aggregator[K, C, C](
                  identity,
                  dep.aggregator.get.mergeCombiners,
                  dep.aggregator.get.mergeCombiners)),
                ordering = Some(dep.keyOrdering.get),
                serializer = dep.serializer)
            } else {
              new ExternalSorter[K, Nothing, C](
                context,
                dep.aggregator.asInstanceOf[Option[Aggregator[K, Nothing, C]]],
                ordering = Some(dep.keyOrdering.get),
                serializer = dep.serializer)
            }
          } else {
            new ExternalSorter[K, C, C](
              context,
              ordering = Some(dep.keyOrdering.get),
              serializer = dep.serializer)
          }
        sorter.insertAll(interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]])
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
        // Use completion callback to stop sorter if task was finished/cancelled.
        context.addTaskCompletionListener[Unit](_ => {
          sorter.stop()
        })
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      } else if (dep.aggregator.isDefined) {
        if (dep.mapSideCombine) {
          // We are reading values that are already combined
          val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
          dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
        } else {
          // We don't know the value type, but also don't care -- the dependency *should*
          // have made sure its compatible w/ this aggregator, which will convert the value
          // type to the combined type C
          val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
          dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
        }
      } else {
        interruptibleIter.asInstanceOf[Iterator[(K, C)]]
      }
    }

    resultIter match {
      case _: InterruptibleIterator[Product2[K, C]] => resultIter
      case _ =>
        // Use another interruptible iterator here to support task cancellation as aggregator
        // or(and) sorter may have consumed previous interruptible iterator.
        new InterruptibleIterator[Product2[K, C]](context, resultIter)
    }
  }