in client/src/main/scala/org/apache/celeborn/client/CommitManager.scala [93:170]
/**
 * Starts the periodic batching of partition commit requests.
 *
 * First registers a `ShutdownWorkerListener` with the lifecycle manager. Then, for the scheduler
 * wrapped by `batchHandleCommitPartitionSchedulerThread` (nothing is scheduled when that wrapper
 * is empty), a fixed-rate task is scheduled with an initial delay of 0 and a period of
 * `batchHandleCommitPartitionRequestInterval` milliseconds; the value returned by the scheduler
 * is kept in `batchHandleCommitPartition`.
 *
 * On each run the task walks `committedPartitionInfo` and submits one job per shuffle to
 * `batchHandleCommitPartitionExecutors`. A job collects the shuffle's unhandled requests grouped
 * by worker and calls `commitFiles` once per worker through `ThreadUtils.parmap`.
 */
def start(): Unit = {
  lifecycleManager.registerWorkerStatusListener(new ShutdownWorkerListener)

  batchHandleCommitPartition = batchHandleCommitPartitionSchedulerThread.map { scheduler =>
    val periodicTask: Runnable = new Runnable {
      override def run(): Unit = {
        committedPartitionInfo.asScala.foreach { case (shuffleId, shuffleCommittedInfo) =>
          // The handler is looked up by the thread running the periodic task, before the job
          // for this shuffle is handed to the executor.
          val commitHandler = getCommitHandler(shuffleId)

          val commitJob: Runnable = new Runnable {
            override def run(): Unit = {
              // Taking over the batch and bumping the in-flight counter happen under the same
              // lock on shuffleCommittedInfo: as soon as this thread owns the requests they are
              // counted as in flight, so stage/partition end can tell when all requests are over.
              val workerToRequests: Map[WorkerInfo, collection.Set[PartitionLocation]] =
                shuffleCommittedInfo.synchronized {
                  val batched: Map[WorkerInfo, collection.Set[PartitionLocation]] =
                    commitHandler.batchUnhandledRequests(shuffleId, shuffleCommittedInfo)
                  commitHandler.incrementInFlightNum(shuffleCommittedInfo, batched)
                  batched
                }

              if (workerToRequests.nonEmpty) {
                val failedWorkers = new ShuffleFailedWorkers()
                // Parallelism handed to parmap: one slot per worker, capped by the configured
                // client RPC maximum.
                val maxConcurrency =
                  Math.min(workerToRequests.size, conf.clientRpcMaxParallelism)
                try {
                  ThreadUtils.parmap(workerToRequests.to, "CommitFiles", maxConcurrency) {
                    case (worker, requests) =>
                      // Use the WorkerInfo instance stored as key in this shuffle's
                      // allocated-workers map; `.get` throws if no key equals `worker`.
                      val allocatedWorker =
                        lifecycleManager.shuffleAllocatedWorkers
                          .get(shuffleId)
                          .asScala
                          .collectFirst {
                            case (candidate, _) if candidate.equals(worker) => candidate
                          }
                          .get

                      // Unique ids of the requested locations that have the given mode.
                      def uniqueIdsIn(mode: PartitionLocation.Mode) =
                        requests.filter(_.getMode == mode).map(_.getUniqueId).toList.asJava

                      val primaryIds = uniqueIdsIn(PartitionLocation.Mode.PRIMARY)
                      val replicaIds = uniqueIdsIn(PartitionLocation.Mode.REPLICA)

                      commitHandler.commitFiles(
                        appUniqueId,
                        shuffleId,
                        shuffleCommittedInfo,
                        allocatedWorker,
                        primaryIds,
                        replicaIds,
                        failedWorkers)
                  }
                  lifecycleManager.workerStatusTracker.recordWorkerFailure(failedWorkers)
                } finally {
                  // Whatever happened above, this batch is no longer in flight.
                  commitHandler.decrementInFlightNum(shuffleCommittedInfo, workerToRequests)
                }
              }
            }
          }

          batchHandleCommitPartitionExecutors.submit(commitJob)
        }
      }
    }

    scheduler.scheduleAtFixedRate(
      periodicTask,
      0,
      batchHandleCommitPartitionRequestInterval,
      TimeUnit.MILLISECONDS)
  }
}