private def handleCommitFiles()

in worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala [324:588]


  /**
   * Handles a CommitFiles request for one `epoch` of a shuffle on this worker.
   *
   * Each (shuffleKey, epoch) pair is committed at most once; its state lives in a [[CommitInfo]]
   * kept in `shuffleCommitInfos`. A repeated request for the same epoch either receives the
   * cached response (commit finished) or polls until the in-flight commit finishes (commit in
   * progress). Every path that finishes a commit caches its response in the [[CommitInfo]] and
   * replies to `context`.
   *
   * `shuffleCommitTimeout` is interpreted in seconds throughout, matching the timer below.
   *
   * @param context     RPC context the [[CommitFilesResponse]] is sent through
   * @param shuffleKey  key of the shuffle to commit
   * @param primaryIds  ids of the primary partitions to commit
   * @param replicaIds  ids of the replica partitions to commit
   * @param mapAttempts per-map attempt ids; -1 entries never overwrite a recorded attempt
   * @param epoch       commit epoch, used to deduplicate retried requests
   */
  private def handleCommitFiles(
      context: RpcCallContext,
      shuffleKey: String,
      primaryIds: jList[String],
      replicaIds: jList[String],
      mapAttempts: Array[Int],
      epoch: Long): Unit = {

    // java.util.concurrent.ConcurrentHashMap#contains(Object) is a legacy alias of
    // containsValue, so key membership must be checked through get/containsKey.
    def alreadyCommitted(shuffleKey: String, epoch: Long): Boolean = {
      val epochCommitInfos = shuffleCommitInfos.get(shuffleKey)
      epochCommitInfos != null && epochCommitInfos.containsKey(epoch)
    }

    // Reply SHUFFLE_NOT_REGISTERED if shuffleKey does not exist AND the shuffle is not committed.
    // Say the first CommitFiles-epoch request succeeds in Worker and removed from
    // partitionLocationInfo, but for some reason the client thinks it failed: the client will
    // retry, so we must check whether the CommitFiles-epoch is already committed here.
    val shuffleRegistered = partitionLocationInfo.containsShuffle(shuffleKey)
    if (!shuffleRegistered && !alreadyCommitted(shuffleKey, epoch)) {
      logError(s"Shuffle $shuffleKey doesn't exist!")
      context.reply(
        CommitFilesResponse(
          StatusCode.SHUFFLE_NOT_REGISTERED,
          List.empty.asJava,
          List.empty.asJava,
          primaryIds,
          replicaIds))
      return
    }

    val shuffleCommitTimeout = conf.workerShuffleCommitTimeout

    shuffleCommitInfos.putIfAbsent(shuffleKey, JavaUtils.newConcurrentHashMap[Long, CommitInfo]())
    val epochCommitMap = shuffleCommitInfos.get(shuffleKey)
    epochCommitMap.putIfAbsent(epoch, new CommitInfo(null, CommitInfo.COMMIT_NOTSTARTED))
    val commitInfo = epochCommitMap.get(epoch)

    // Replies with the cached response if the commit of this epoch has finished.
    // Returns whether a reply was sent.
    def replyIfCommitFinished(): Boolean = commitInfo.synchronized {
      val finished = commitInfo.status == CommitInfo.COMMIT_FINISHED
      if (finished) {
        context.reply(commitInfo.response)
      }
      finished
    }

    // Polls until the in-flight commit of this epoch finishes, then replies with its response.
    // The deadline is derived in seconds to match the commit timer; the original loop compared
    // a millisecond counter against the seconds value and gave up almost immediately.
    def waitForCommitFinish(): Unit = {
      val pollIntervalMs = 100L
      val deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(shuffleCommitTimeout)
      var replied = replyIfCommitFinished()
      while (!replied && System.nanoTime() - deadline < 0) {
        Thread.sleep(pollIntervalMs)
        replied = replyIfCommitFinished()
      }
      if (!replied) {
        logWarning(s"Waiting for in-progress CommitFiles of $shuffleKey epoch $epoch did not " +
          s"finish within $shuffleCommitTimeout s, no response sent.")
      }
    }

    commitInfo.synchronized {
      if (commitInfo.status == CommitInfo.COMMIT_FINISHED) {
        logInfo(s"${shuffleKey} CommitFinished, just return the response")
        context.reply(commitInfo.response)
        return
      } else if (commitInfo.status == CommitInfo.COMMIT_INPROCESS) {
        logInfo(s"${shuffleKey} CommitFiles inprogress, wait for finish")
        // NOTE(review): this poller occupies a commitThreadPool thread while it sleeps; the
        // reply stage below deliberately avoids that pool — confirm this is acceptable.
        commitThreadPool.submit(new Runnable {
          override def run(): Unit = {
            waitForCommitFinish()
          }
        })
        return
      } else {
        logInfo(s"Start commitFiles for ${shuffleKey}")
        commitInfo.status = CommitInfo.COMMIT_INPROCESS
        workerSource.startTimer(WorkerSource.COMMIT_FILES_TIME, shuffleKey)
      }
    }

    // Update shuffleMapperAttempts: record each known (!= -1) attempt that is not recorded yet.
    shuffleMapperAttempts.putIfAbsent(shuffleKey, new AtomicIntegerArray(mapAttempts))
    val attempts = shuffleMapperAttempts.get(shuffleKey)
    if (mapAttempts.exists(_ != -1)) {
      attempts.synchronized {
        (0 until attempts.length()).foreach { idx =>
          if (mapAttempts(idx) != -1 && attempts.get(idx) == -1) {
            attempts.set(idx, mapAttempts(idx))
          }
        }
      }
    }

    // Use ConcurrentSet to avoid excessive lock contention.
    val committedPrimaryIds = ConcurrentHashMap.newKeySet[String]()
    val committedReplicaIds = ConcurrentHashMap.newKeySet[String]()
    val emptyFilePrimaryIds = ConcurrentHashMap.newKeySet[String]()
    val emptyFileReplicaIds = ConcurrentHashMap.newKeySet[String]()
    val failedPrimaryIds = ConcurrentHashMap.newKeySet[String]()
    val failedReplicaIds = ConcurrentHashMap.newKeySet[String]()
    val committedPrimaryStorageInfos = JavaUtils.newConcurrentHashMap[String, StorageInfo]()
    val committedReplicaStorageInfos = JavaUtils.newConcurrentHashMap[String, StorageInfo]()
    val committedMapIdBitMap = JavaUtils.newConcurrentHashMap[String, RoaringBitmap]()
    val partitionSizeList = new LinkedBlockingQueue[Long]()

    val primaryFuture =
      commitFiles(
        shuffleKey,
        primaryIds,
        committedPrimaryIds,
        emptyFilePrimaryIds,
        failedPrimaryIds,
        committedPrimaryStorageInfos,
        committedMapIdBitMap,
        partitionSizeList)
    val replicaFuture = commitFiles(
      shuffleKey,
      replicaIds,
      committedReplicaIds,
      emptyFileReplicaIds,
      failedReplicaIds,
      committedReplicaStorageInfos,
      committedMapIdBitMap,
      partitionSizeList,
      false)

    // Combined completion of whichever commit futures exist; null when there is nothing to wait
    // for.
    val future =
      if (primaryFuture != null && replicaFuture != null) {
        CompletableFuture.allOf(primaryFuture, replicaFuture)
      } else if (primaryFuture != null) {
        primaryFuture
      } else if (replicaFuture != null) {
        replicaFuture
      } else {
        null
      }

    // Success path: releases the slots of the committed ids, builds the response from the
    // collected results, caches it in commitInfo and replies.
    def reply(): Unit = {
      // release slots before reply.
      val releasePrimaryLocations =
        partitionLocationInfo.removePrimaryPartitions(shuffleKey, primaryIds)
      val releaseReplicaLocations =
        partitionLocationInfo.removeReplicaPartitions(shuffleKey, replicaIds)
      logDebug(s"$shuffleKey remove" +
        s" slots count ${releasePrimaryLocations._2 + releaseReplicaLocations._2}")
      logDebug(s"CommitFiles result" +
        s" $committedPrimaryStorageInfos $committedReplicaStorageInfos")
      workerInfo.releaseSlots(shuffleKey, releasePrimaryLocations._1)
      workerInfo.releaseSlots(shuffleKey, releaseReplicaLocations._1)

      val committedPrimaryIdList = new jArrayList[String](committedPrimaryIds)
      val committedReplicaIdList = new jArrayList[String](committedReplicaIds)
      val failedPrimaryIdList = new jArrayList[String](failedPrimaryIds)
      val failedReplicaIdList = new jArrayList[String](failedReplicaIds)
      val committedPrimaryStorageAndDiskHintList =
        new jHashMap[String, StorageInfo](committedPrimaryStorageInfos)
      val committedReplicaStorageAndDiskHintList =
        new jHashMap[String, StorageInfo](committedReplicaStorageInfos)
      val committedMapIdBitMapList = new jHashMap[String, RoaringBitmap](committedMapIdBitMap)
      val totalSize = partitionSizeList.asScala.sum
      val fileCount = partitionSizeList.size()
      // reply
      val response =
        if (failedPrimaryIds.isEmpty && failedReplicaIds.isEmpty) {
          logInfo(
            s"CommitFiles for $shuffleKey success with " +
              s"${committedPrimaryIds.size()} committed primary partitions, " +
              s"${emptyFilePrimaryIds.size()} empty primary partitions, " +
              s"${failedPrimaryIds.size()} failed primary partitions, " +
              s"${committedReplicaIds.size()} committed replica partitions, " +
              s"${emptyFileReplicaIds.size()} empty replica partitions, " +
              s"${failedReplicaIds.size()} failed replica partitions.")
          CommitFilesResponse(
            StatusCode.SUCCESS,
            committedPrimaryIdList,
            committedReplicaIdList,
            List.empty.asJava,
            List.empty.asJava,
            committedPrimaryStorageAndDiskHintList,
            committedReplicaStorageAndDiskHintList,
            committedMapIdBitMapList,
            totalSize,
            fileCount)
        } else {
          logWarning(
            s"CommitFiles for $shuffleKey failed with " +
              s"${committedPrimaryIds.size()} committed primary partitions, " +
              s"${emptyFilePrimaryIds.size()} empty primary partitions, " +
              s"${failedPrimaryIds.size()} failed primary partitions, " +
              s"${committedReplicaIds.size()} committed replica partitions, " +
              s"${emptyFileReplicaIds.size()} empty replica partitions, " +
              s"${failedReplicaIds.size()} failed replica partitions.")
          CommitFilesResponse(
            StatusCode.PARTIAL_SUCCESS,
            committedPrimaryIdList,
            committedReplicaIdList,
            failedPrimaryIdList,
            failedReplicaIdList,
            committedPrimaryStorageAndDiskHintList,
            committedReplicaStorageAndDiskHintList,
            committedMapIdBitMapList,
            totalSize,
            fileCount)
        }
      if (testRetryCommitFiles) {
        Thread.sleep(5000)
      }
      commitInfo.synchronized {
        commitInfo.response = response
        commitInfo.status = CommitInfo.COMMIT_FINISHED
      }
      context.reply(response)

      workerSource.stopTimer(WorkerSource.COMMIT_FILES_TIME, shuffleKey)
    }

    if (future != null) {
      val timeout = timer.newTimeout(
        new TimerTask {
          override def run(timeout: Timeout): Unit = {
            // Cancel the commit future itself, not the handleAsync stage: a cancelled dependent
            // stage never invokes its function, which would leave the commit IN_PROCESS forever
            // with no reply. Cancelling `future` makes the handler below run with a
            // CancellationException. CompletableFuture#cancel does not interrupt the underlying
            // commit tasks.
            if (future.cancel(true)) {
              logWarning(s"After waiting $shuffleCommitTimeout s, cancel all commit file jobs.")
            }
          }
        },
        shuffleCommitTimeout,
        TimeUnit.SECONDS)

      // should not use commitThreadPool in case of block by commit files.
      future.handleAsync(
        new BiFunction[Void, Throwable, Unit] {
          override def apply(v: Void, t: Throwable): Unit = {
            if (null != t) {
              t match {
                case _: CancellationException =>
                  logWarning("While handling commitFiles, canceled.")
                case ee: ExecutionException =>
                  logError("While handling commitFiles, ExecutionException raised.", ee)
                case _: InterruptedException =>
                  logWarning("While handling commitFiles, interrupted.")
                  // Restore the flag but do not rethrow: rethrowing would only fail this
                  // unobserved stage and skip the reply below.
                  Thread.currentThread().interrupt()
                case _: TimeoutException =>
                  logWarning(s"While handling commitFiles, timeout after $shuffleCommitTimeout s.")
                case throwable: Throwable =>
                  logError("While handling commitFiles, exception occurs.", throwable)
              }
              // The commit failed before the timer fired (or was cancelled by it); either way the
              // timer is no longer needed.
              timeout.cancel()
              // NOTE(review): unlike reply(), this path does not release the slots of
              // primaryIds/replicaIds — confirm they are reclaimed elsewhere.
              val failedResponse = CommitFilesResponse(
                StatusCode.COMMIT_FILE_EXCEPTION,
                List.empty.asJava,
                List.empty.asJava,
                primaryIds,
                replicaIds)
              commitInfo.synchronized {
                commitInfo.response = failedResponse
                commitInfo.status = CommitInfo.COMMIT_FINISHED
              }
              // Reply now so the client does not have to wait for its own RPC timeout.
              context.reply(failedResponse)
              workerSource.stopTimer(WorkerSource.COMMIT_FILES_TIME, shuffleKey)
            } else {
              // finish, cancel timeout job first.
              timeout.cancel()
              reply()
            }
          }
        },
        asyncReplyPool)
    } else {
      // If both of two futures are null, then reply directly.
      reply()
    }
  }