def handlePushMergedData()

in worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala [447:818]


  /**
   * Handles a PushMergedData request. Its single body carries data for several partitions of
   * one shuffle (`partitionUniqueIds`). The body is split by `batchOffsets`, presumably one
   * offset per partition slice; verify against writeLocalData.
   *
   * Processing order:
   *  1. Resolve this worker's locations for the partitions (primary or replica mode).
   *  2. A missing location is either marked HARD_SPLIT, or the request is answered with
   *     MAP_ENDED / PUSH_DATA_FAIL_PARTITION_NOT_FOUND.
   *  3. While the worker is shutting down, every partition is answered with HARD_SPLIT so the
   *     client can revive and retry.
   *  4. If any file writer already recorded an exception, the request fails with
   *     PUSH_DATA_WRITE_FAIL_PRIMARY / PUSH_DATA_WRITE_FAIL_REPLICA.
   *  5. Closed writers, and writers for which checkDiskFullAndSplit reports a split, are marked
   *     as split. Every writer that will still receive data gets a pending write registered.
   *  6. On a primary whose first location has a peer, the body is forwarded to the replica on
   *     `replicateThreadPool` while it is written locally; the reply combines both outcomes.
   *     Otherwise the data is written locally only, and the reply follows the local write.
   *
   * @param pushMergedData the request; its body must be a NettyManagedBuffer (it is cast)
   * @param callback       answers the client; wrapped with a timer for the push-time metric
   */
  def handlePushMergedData(
      pushMergedData: PushMergedData,
      callback: RpcResponseCallback): Unit = {
    val shuffleKey = pushMergedData.shuffleKey
    val mode = PartitionLocation.getMode(pushMergedData.mode)
    val batchOffsets = pushMergedData.batchOffsets
    val body = pushMergedData.body.asInstanceOf[NettyManagedBuffer].getBuf
    val isPrimary = mode == PartitionLocation.Mode.PRIMARY
    val (mapId, attemptId) = getMapAttempt(body)

    val key = s"${pushMergedData.requestId}"
    // Primary and replica pushes are timed under separate WorkerSource metrics.
    val callbackWithTimer =
      if (isPrimary) {
        new RpcResponseCallbackWithTimer(
          workerSource,
          WorkerSource.PRIMARY_PUSH_DATA_TIME,
          key,
          callback)
      } else {
        new RpcResponseCallbackWithTimer(
          workerSource,
          WorkerSource.REPLICA_PUSH_DATA_TIME,
          key,
          callback)
      }
    // Accumulates per-partition split statuses and sends the single reply for this request.
    val pushMergedDataCallback = new PushMergedDataCallback(callbackWithTimer)

    // For test: drop the first primary / replica request without replying
    // (presumably to exercise client-side push timeouts).
    if (isPrimary && testPushPrimaryDataTimeout &&
      !PushDataHandler.pushPrimaryMergeDataTimeoutTested.getAndSet(true)) {
      return
    }

    if (!isPrimary && testPushReplicaDataTimeout &&
      !PushDataHandler.pushReplicaMergeDataTimeoutTested.getAndSet(true)) {
      return
    }

    val partitionIdToLocations =
      if (isPrimary) {
        partitionLocationInfo.getPrimaryLocations(shuffleKey, pushMergedData.partitionUniqueIds)
      } else {
        partitionLocationInfo.getReplicaLocations(shuffleKey, pushMergedData.partitionUniqueIds)
      }

    // Replication is decided once for the whole merged body, from its first location.
    // Fetching the real batchId from the body would add cost and is meaningless for replication.
    val doReplicate =
      partitionIdToLocations.head._2 != null && partitionIdToLocations.head._2.hasPeer && isPrimary

    // Validate the locations. A missing one is a committed HARD_SPLIT partition (reply
    // HARD_SPLIT for it), an already-ended mapper (reply MAP_ENDED), or unknown (fail).
    var index = 0
    while (index < partitionIdToLocations.length) {
      val (id, loc) = partitionIdToLocations(index)
      if (loc == null) {
        // MapperAttempts for a shuffle exist after any CommitFiles request succeeds.
        // A shuffle can trigger multiple CommitFiles requests, e.g. when a HARD_SPLIT happens
        // or at StageEnd. If MapperAttempts exist but the value for mapId is -1 (the map has
        // not finished yet), it is probably because CommitFiles ran for a HARD_SPLIT.
        if (shuffleMapperAttempts.containsKey(shuffleKey)) {
          if (-1 != shuffleMapperAttempts.get(shuffleKey).get(mapId)) {
            logDebug(s"Receive push merged data from speculative " +
              s"task(shuffle $shuffleKey, map $mapId, attempt $attemptId), " +
              s"but this mapper has already been ended.")
            pushMergedDataCallback.onSuccess(StatusCode.MAP_ENDED)
            return
          } else {
            logDebug(s"[Case1] Receive push merged data for committed hard split partition of " +
              s"(shuffle $shuffleKey, map $mapId attempt $attemptId)")
            workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
            pushMergedDataCallback.addSplitPartition(index, StatusCode.HARD_SPLIT)
          }
        } else {
          if (storageManager.shuffleKeySet().contains(shuffleKey)) {
            // If there is no shuffle key in shuffleMapperAttempts but there is one in
            // StorageManager, this partition should be a HARD_SPLIT partition, and some tasks
            // still push data to it after a worker restart.
            logDebug(s"[Case2] Receive push merged data for committed hard split partition of " +
              s"(shuffle $shuffleKey, map $mapId attempt $attemptId)")
            workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
            pushMergedDataCallback.addSplitPartition(index, StatusCode.HARD_SPLIT)
          } else {
            logWarning(s"While handling PushMergedData, Partition location wasn't found for " +
              s"task(shuffle $shuffleKey, map $mapId, attempt $attemptId, uniqueId $id).")
            pushMergedDataCallback.onFailure(
              new CelebornIOException(StatusCode.PUSH_DATA_FAIL_PARTITION_NOT_FOUND))
            return
          }
        }
      }
      index += 1
    }

    // During worker shutdown, reply HARD_SPLIT for every partition. This check runs before the
    // writer-exception check below, so the current push can revive and retry instead of failing.
    if (shutdown.get()) {
      partitionIdToLocations.indices.foreach(index =>
        pushMergedDataCallback.addSplitPartition(index, StatusCode.HARD_SPLIT))
      pushMergedDataCallback.onSuccess(StatusCode.HARD_SPLIT)
      return
    }

    val (fileWriters, exceptionFileWriterIndexOpt) = getFileWriters(partitionIdToLocations)
    if (exceptionFileWriterIndexOpt.isDefined) {
      val fileWriterWithException = fileWriters(exceptionFileWriterIndexOpt.get)
      val cause =
        if (isPrimary) {
          StatusCode.PUSH_DATA_WRITE_FAIL_PRIMARY
        } else {
          StatusCode.PUSH_DATA_WRITE_FAIL_REPLICA
        }
      logError(
        s"While handling PushMergedData, throw $cause, fileWriter $fileWriterWithException has exception.",
        fileWriterWithException.getException)
      workerSource.incCounter(WorkerSource.WRITE_DATA_FAIL_COUNT)
      pushMergedDataCallback.onFailure(new CelebornIOException(cause))
      return
    }

    // Undoes incrementPendingWrites() for the writers in [0, upTo) that registered one in the
    // loop below. Call it whenever the request is abandoned before writeLocalData runs;
    // otherwise those pending writes would never be released.
    def rollbackPendingWrites(upTo: Int): Unit = {
      var i = 0
      while (i < upTo) {
        val writer = fileWriters(i)
        if (writer != null && !pushMergedDataCallback.isHardSplitPartition(i)) {
          writer.decrementPendingWrites()
        }
        i += 1
      }
    }

    var fileWriterIndex = 0
    val totalFileWriters = fileWriters.length
    while (fileWriterIndex < totalFileWriters) {
      val fileWriter = fileWriters(fileWriterIndex)
      if (fileWriter == null) {
        if (!pushMergedDataCallback.isHardSplitPartition(fileWriterIndex)) {
          // Earlier writers may already have registered pending writes; release them before
          // abandoning the request.
          rollbackPendingWrites(fileWriterIndex)
          pushMergedDataCallback.onFailure(
            new CelebornIOException(s"Partition $fileWriterIndex's fileWriter not found," +
              s" but it hasn't been identified in the previous validation step."))
          return
        }
      } else {
        if (fileWriter.isClosed) {
          val fileInfo = fileWriter.getCurrentFileInfo
          logWarning(
            s"[handlePushMergedData] FileWriter is already closed! File path ${fileInfo.getFilePath} " +
              s"length ${fileInfo.getFileLength}")
          pushMergedDataCallback.addSplitPartition(fileWriterIndex, StatusCode.HARD_SPLIT)
        } else {
          val splitStatus = checkDiskFullAndSplit(fileWriter, isPrimary)
          if (splitStatus == StatusCode.HARD_SPLIT) {
            logWarning(
              s"return hard split for disk full with shuffle $shuffleKey map $mapId attempt $attemptId")
            workerSource.incCounter(WorkerSource.WRITE_DATA_HARD_SPLIT_COUNT)
            pushMergedDataCallback.addSplitPartition(fileWriterIndex, StatusCode.HARD_SPLIT)
          } else if (splitStatus == StatusCode.SOFT_SPLIT) {
            pushMergedDataCallback.addSplitPartition(fileWriterIndex, StatusCode.SOFT_SPLIT)
          }
        }
        // Only writers that will actually receive data register a pending write.
        if (!pushMergedDataCallback.isHardSplitPartition(fileWriterIndex)) {
          fileWriter.incrementPendingWrites()
        }
      }
      fileWriterIndex += 1
    }

    // Copies HARD_SPLIT statuses reported by the local write into the reply.
    def recordLocalHardSplits(result: Array[StatusCode]): Unit = {
      var i = 0
      while (i < result.length) {
        if (result(i) == StatusCode.HARD_SPLIT) {
          pushMergedDataCallback.addSplitPartition(i, result(i))
        }
        i += 1
      }
    }

    // True when a CongestionController exists and reports the user of this request as congested.
    // The first non-null writer's context stands for the whole request.
    def localUserCongested(): Boolean =
      Option(CongestionController.instance()).exists { congestionController =>
        fileWriters
          .find(_ != null)
          .map(_.getUserCongestionControlContext)
          .exists(congestionController.isUserCongested)
      }

    val hardSplitIndexes = pushMergedDataCallback.getHardSplitIndexes
    // Completed by writeLocalData with per-partition statuses of the local write.
    val writePromise = Promise[Array[StatusCode]]()
    // For primary, send data to replica
    if (doReplicate) {
      val location = partitionIdToLocations.head._2
      val peer = location.getPeer
      val peerWorker = new WorkerInfo(
        peer.getHost,
        peer.getRpcPort,
        peer.getPushPort,
        peer.getFetchPort,
        peer.getReplicatePort)
      if (unavailablePeers.containsKey(peerWorker)) {
        // The request is abandoned before writeLocalData, so release every pending write.
        rollbackPendingWrites(totalFileWriters)
        handlePushMergedDataConnectionFail(pushMergedDataCallback, location)
        return
      }
      // Keep the body alive for the asynchronous replication; every failure path that does
      // not hand it to the replicate client releases it explicitly.
      pushMergedData.body().retain()
      replicateThreadPool.submit(new Runnable {
        override def run(): Unit = {
          if (unavailablePeers.containsKey(peerWorker)) {
            pushMergedData.body().release()
            handlePushMergedDataConnectionFail(pushMergedDataCallback, location)
            return
          }
          // Handle the response from replica
          val wrappedCallback = new RpcResponseCallback() {
            override def onSuccess(response: ByteBuffer): Unit = {
              // During the rolling upgrade of the worker cluster, it is possible for
              // the primary worker to be upgraded to a new version that includes
              // the changes from [CELEBORN-1721], while the replica worker is still running
              // on an older version that does not have these changes.
              // In this scenario, the replica may return a response without any context
              // when status of SUCCESS.
              val replicaReason =
                if (response.remaining() > 0) {
                  response.get()
                } else {
                  // Use the byte value so both branches have the same type as response.get();
                  // mixing in the enum itself would widen replicaReason to Any.
                  StatusCode.SUCCESS.getValue
                }
              if (replicaReason == StatusCode.HARD_SPLIT.getValue) {
                if (response.remaining() > 0) {
                  try {
                    val pushMergedDataResponse: PbPushMergedDataSplitPartitionInfo =
                      TransportMessage.fromByteBuffer(
                        response).getParsedPayload[PbPushMergedDataSplitPartitionInfo]()
                    pushMergedDataCallback.unionReplicaSplitPartitions(
                      pushMergedDataResponse.getSplitPartitionIndexesList,
                      pushMergedDataResponse.getStatusCodesList)
                  } catch {
                    case e: CelebornIOException =>
                      pushMergedDataCallback.onFailure(e)
                      return
                    case e: IllegalArgumentException =>
                      pushMergedDataCallback.onFailure(new CelebornIOException(e))
                      return
                  }
                } else {
                  // During the rolling upgrade of the worker cluster, it is possible for the primary worker
                  // to be upgraded to a new version that includes the changes from [CELEBORN-1721], while
                  // the replica worker is still running on an older version that does not have these changes.
                  // In this scenario, the replica may return a response with a status of HARD_SPLIT, but
                  // will not provide a PbPushMergedDataSplitPartitionInfo.
                  logWarning(
                    s"The response status from the replica (shuffle $shuffleKey map $mapId attempt $attemptId) is HARD_SPLIT, but no PbPushMergedDataSplitPartitionInfo is present.")
                  partitionIdToLocations.indices.foreach(index =>
                    pushMergedDataCallback.addSplitPartition(index, StatusCode.HARD_SPLIT))
                  pushMergedDataCallback.onSuccess(StatusCode.HARD_SPLIT)
                  return
                }
              }
              // The reply must reflect the local write too, so wait for it before answering.
              Try(Await.result(writePromise.future, Duration.Inf)) match {
                case Success(result) =>
                  recordLocalHardSplits(result)
                  // Only primary data enable replication will push data to replica
                  if (localUserCongested()) {
                    // The primary is congested even though the replica did not report
                    // congestion (its response is empty).
                    pushMergedDataCallback.onSuccess(
                      StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED)
                  } else if (replicaReason == StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue) {
                    pushMergedDataCallback.onSuccess(
                      StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED)
                  } else {
                    pushMergedDataCallback.onSuccess(StatusCode.SUCCESS)
                  }
                // Route failures through pushMergedDataCallback like every other failure here.
                case Failure(e) => pushMergedDataCallback.onFailure(e)
              }
            }

            override def onFailure(e: Throwable): Unit = {
              logError(s"PushMergedData replicate failed for partitionLocation: $location", e)
              // 1. Throw PUSH_DATA_WRITE_FAIL_REPLICA by replica peer worker
              // 2. Throw PUSH_DATA_TIMEOUT_REPLICA by TransportResponseHandler
              // 3. Throw IOException by channel, convert to PUSH_DATA_CONNECTION_EXCEPTION_REPLICA
              // Throwable.getMessage may be null. Treat it as empty so this classification cannot
              // throw a NullPointerException and leave the client without any reply.
              val errorMsg = Option(e.getMessage).getOrElse("")
              if (errorMsg.startsWith(StatusCode.PUSH_DATA_WRITE_FAIL_REPLICA.name())) {
                workerSource.incCounter(WorkerSource.REPLICATE_DATA_WRITE_FAIL_COUNT)
                pushMergedDataCallback.onFailure(e)
              } else if (errorMsg.startsWith(StatusCode.PUSH_DATA_TIMEOUT_REPLICA.name())) {
                workerSource.incCounter(WorkerSource.REPLICATE_DATA_TIMEOUT_COUNT)
                pushMergedDataCallback.onFailure(e)
              } else if (ExceptionUtils.connectFail(errorMsg)) {
                workerSource.incCounter(WorkerSource.REPLICATE_DATA_CONNECTION_EXCEPTION_COUNT)
                pushMergedDataCallback.onFailure(
                  new CelebornIOException(StatusCode.PUSH_DATA_CONNECTION_EXCEPTION_REPLICA))
              } else {
                workerSource.incCounter(WorkerSource.REPLICATE_DATA_FAIL_NON_CRITICAL_CAUSE_COUNT)
                pushMergedDataCallback.onFailure(
                  new CelebornIOException(StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_REPLICA))
              }
            }
          }

          try {
            val client = getReplicateClient(peer.getHost, peer.getReplicatePort, location.getId)
            val newPushMergedData = new PushMergedData(
              PartitionLocation.Mode.REPLICA.mode(),
              shuffleKey,
              pushMergedData.partitionUniqueIds,
              batchOffsets,
              pushMergedData.body)
            client.pushMergedData(
              newPushMergedData,
              shufflePushDataTimeout.get(shuffleKey),
              wrappedCallback)
          } catch {
            case e: Exception =>
              pushMergedData.body().release()
              unavailablePeers.put(peerWorker, System.currentTimeMillis())
              workerSource.incCounter(WorkerSource.REPLICATE_DATA_CREATE_CONNECTION_FAIL_COUNT)
              logError(
                s"PushMergedData replication failed during connecting peer for partitionLocation: $location",
                e)
              pushMergedDataCallback.onFailure(
                new CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_REPLICA))
          }
        }
      })
      writeLocalData(
        fileWriters,
        body,
        shuffleKey,
        isPrimary,
        Some(batchOffsets),
        writePromise,
        hardSplitIndexes)
    } else {
      // The codes here could be executed if
      // 1. the client doesn't enable push data to the replica, the primary worker could hit here
      // 2. the client enables push data to the replica, and the replica worker could hit here
      writeLocalData(
        fileWriters,
        body,
        shuffleKey,
        isPrimary,
        Some(batchOffsets),
        writePromise,
        hardSplitIndexes)
      Try(Await.result(writePromise.future, Duration.Inf)) match {
        case Success(result) =>
          recordLocalHardSplits(result)
          if (localUserCongested()) {
            if (isPrimary) {
              pushMergedDataCallback.onSuccess(StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED)
            } else {
              pushMergedDataCallback.onSuccess(StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED)
            }
          } else {
            pushMergedDataCallback.onSuccess(StatusCode.SUCCESS)
          }
        case Failure(e) => pushMergedDataCallback.onFailure(e)
      }
    }
  }