in client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala [226:404]
/**
 * Sends CommitFiles RPCs to all target workers in parallel and aggregates the
 * results into `shuffleCommittedInfo`.
 *
 * Flow:
 *   1. For every param that has at least one primary or replica id, asynchronously
 *      (on `ec`) build a `CommitFiles` message and fire the RPC, tracking each
 *      in-flight request as a `CommitFutureWithStatus` in a shared queue.
 *   2. Poll the queue every `delta` ms, up to an overall budget of
 *      `timeout * maxRetries`: record completed responses, and retry requests that
 *      failed, returned a mock failure, or exceeded the per-attempt timeout
 *      (at most `maxRetries` attempts each).
 *   3. Anything still pending when the budget expires is recorded as failed.
 *
 * @param shuffleId                 id of the shuffle being committed
 * @param shuffleCommittedInfo      shared accumulator for commit results; mutated
 *                                  under its own lock in `processResponse`
 * @param params                    per-worker commit targets (worker + primary/replica ids)
 * @param commitFilesFailedWorkers  out-param: workers whose commit did not fully
 *                                  succeed, with status code and failure timestamp
 */
def doParallelCommitFiles(
    shuffleId: Int,
    shuffleCommittedInfo: ShuffleCommittedInfo,
    params: ArrayBuffer[CommitFilesParam],
    commitFilesFailedWorkers: ShuffleFailedWorkers): Unit = {

  // Re-sends the CommitFiles RPC for a status entry, mutating it in place:
  // bumps the attempt counter, resets the per-attempt start time, and swaps in
  // the new future. On the final allowed attempt, mockFailure is forced off so
  // an injected test failure cannot make the last retry fail as well.
  def retryCommitFiles(status: CommitFutureWithStatus, currentTime: Long): Unit = {
    status.retriedTimes = status.retriedTimes + 1
    status.startTime = currentTime
    val mockFailure = status.message.mockFailure && (status.retriedTimes < maxRetries)
    val msg =
      status.message.copy(mockFailure = mockFailure)
    status.future = commitFiles(
      status.workerInfo,
      msg)
  }

  // Builds a synthetic REQUEST_FAILED response that echoes back the request's
  // primary/replica ids, so processResponse records them all as failed on this worker.
  def createFailResponse(status: CommitFutureWithStatus): CommitFilesResponse = {
    CommitFilesResponse(
      StatusCode.REQUEST_FAILED,
      List.empty.asJava,
      List.empty.asJava,
      status.message.primaryIds,
      status.message.replicaIds)
  }

  // Merges one worker's response into the shared accumulator. All mutation of
  // shuffleCommittedInfo happens inside its monitor, since responses may be
  // processed while other threads read/write the same structure.
  def processResponse(res: CommitFilesResponse, worker: WorkerInfo): Unit = {
    shuffleCommittedInfo.synchronized {
      // record committed partitionIds
      res.committedPrimaryIds.asScala.foreach {
        case commitPrimaryId =>
          // Group committed unique ids under their numeric partition id.
          val partitionUniqueIdList = shuffleCommittedInfo.committedPrimaryIds.computeIfAbsent(
            Utils.splitPartitionLocationUniqueId(commitPrimaryId)._1,
            (k: Int) => new util.ArrayList[String]())
          partitionUniqueIdList.add(commitPrimaryId)
      }
      res.committedReplicaIds.asScala.foreach {
        case commitReplicaId =>
          val partitionUniqueIdList = shuffleCommittedInfo.committedReplicaIds.computeIfAbsent(
            Utils.splitPartitionLocationUniqueId(commitReplicaId)._1,
            (k: Int) => new util.ArrayList[String]())
          partitionUniqueIdList.add(commitReplicaId)
      }
      // record committed partitions storage hint and disk hint
      shuffleCommittedInfo.committedPrimaryStorageInfos.putAll(res.committedPrimaryStorageInfos)
      shuffleCommittedInfo.committedReplicaStorageInfos.putAll(res.committedReplicaStorageInfos)
      // record failed partitions, each mapped to the worker it failed on
      shuffleCommittedInfo.failedPrimaryPartitionIds.putAll(
        res.failedPrimaryIds.asScala.map((_, worker)).toMap.asJava)
      shuffleCommittedInfo.failedReplicaPartitionIds.putAll(
        res.failedReplicaIds.asScala.map((_, worker)).toMap.asJava)
      shuffleCommittedInfo.committedMapIdBitmap.putAll(res.committedMapIdBitMap)
      totalWritten.add(res.totalWritten)
      fileCount.add(res.fileCount)
      shuffleCommittedInfo.currentShuffleFileCount.add(res.fileCount)
    }
  }

  // Thread-safe queue: populated by the async submission futures below while the
  // polling loop later iterates/removes entries.
  val futures = new LinkedBlockingQueue[CommitFutureWithStatus]()

  val startTime = System.currentTimeMillis()
  // Fire one CommitFiles RPC per worker that actually has partitions to commit;
  // submission itself runs on `ec` so building/sending requests is parallel.
  val outFutures = params.filter(param =>
    !CollectionUtils.isEmpty(param.primaryIds) ||
      !CollectionUtils.isEmpty(param.replicaIds)) map { param =>
    Future {
      val msg = CommitFiles(
        appUniqueId,
        shuffleId,
        param.primaryIds,
        param.replicaIds,
        getMapperAttempts(shuffleId),
        commitEpoch.incrementAndGet(),
        mockCommitFilesFailure)
      val future = commitFiles(param.worker, msg)
      // retriedTimes starts at 1: this is attempt 1 of maxRetries.
      futures.add(CommitFutureWithStatus(future, msg, param.worker, 1, startTime))
    }(ec)
  }
  val cbf =
    implicitly[
      CanBuildFrom[ArrayBuffer[Future[Boolean]], Boolean, ArrayBuffer[Boolean]]]
  val futureSeq = Future.sequence(outFutures)(cbf, ec)
  // Block until every RPC has been *submitted* (not completed) so `futures`
  // holds the full set before polling starts.
  awaitResult(futureSeq, Duration.Inf)

  // Per-attempt timeout, and an overall polling budget covering all retries.
  val timeout = clientRpcCommitFilesAskTimeout.duration.toMillis
  var remainingTime = timeout * maxRetries
  val delta = 50 // polling interval in ms
  while (remainingTime >= 0 && !futures.isEmpty) {
    val currentTime = System.currentTimeMillis()
    val iter = futures.iterator()
    while (iter.hasNext) {
      val status = iter.next()
      val worker = status.workerInfo
      if (status.future.isCompleted) {
        status.future.value.get match {
          case scala.util.Success(res) =>
            res.status match {
              // Terminal statuses: record the response (success or failure) and
              // drop the entry from the queue.
              case StatusCode.SUCCESS | StatusCode.PARTIAL_SUCCESS | StatusCode.SHUFFLE_NOT_REGISTERED | StatusCode.REQUEST_FAILED | StatusCode.WORKER_EXCLUDED | StatusCode.COMMIT_FILE_EXCEPTION =>
                if (res.status == StatusCode.SUCCESS) {
                  logDebug(s"Request commitFiles return ${res.status} for " +
                    s"${Utils.makeShuffleKey(appUniqueId, shuffleId)} from worker ${worker.readableAddress()}")
                } else {
                  logWarning(s"Request commitFiles return ${res.status} for " +
                    s"${Utils.makeShuffleKey(appUniqueId, shuffleId)} from worker ${worker.readableAddress()}")
                  // WORKER_EXCLUDED is not the worker's fault, so it is not
                  // reported to the failed-workers map.
                  if (res.status != StatusCode.WORKER_EXCLUDED) {
                    commitFilesFailedWorkers.put(worker, (res.status, System.currentTimeMillis()))
                  }
                }
                processResponse(res, worker)
                iter.remove()
              // Injected test failure: retry until attempts are exhausted, then
              // record everything in the request as failed.
              case StatusCode.COMMIT_FILES_MOCK_FAILURE =>
                if (status.retriedTimes < maxRetries) {
                  logError(s"Request commitFiles return ${res.status} for " +
                    s"${Utils.makeShuffleKey(appUniqueId, shuffleId)} for ${status.retriedTimes}/$maxRetries, will retry")
                  retryCommitFiles(status, currentTime)
                } else {
                  logError(
                    s"Request commitFiles return ${StatusCode.COMMIT_FILES_MOCK_FAILURE} for " +
                      s"${Utils.makeShuffleKey(appUniqueId, shuffleId)} for ${status.retriedTimes}/$maxRetries, will not retry")
                  val res = createFailResponse(status)
                  processResponse(res, worker)
                  iter.remove()
                }
              // NOTE(review): an unexpected status neither retries nor removes
              // the entry, so it is re-examined every poll until the overall
              // budget expires and the trailing loop marks it failed — confirm
              // this fall-through is intended.
              case _ =>
                logError(s"Should never reach here! commit files response status ${res.status}")
            }
          // RPC-level failure (exception): retry with the same message, or give
          // up and record the whole request as failed after maxRetries attempts.
          case scala.util.Failure(e) =>
            if (status.retriedTimes < maxRetries) {
              logError(
                s"Ask worker(${worker.readableAddress()}) CommitFiles for $shuffleId failed" +
                  s" (attempt ${status.retriedTimes}/$maxRetries), will retry.",
                e)
              retryCommitFiles(status, currentTime)
            } else {
              logError(
                s"Ask worker(${worker.readableAddress()}) CommitFiles for $shuffleId failed" +
                  s" (attempt ${status.retriedTimes}/$maxRetries), will not retry.",
                e)
              val res = createFailResponse(status)
              processResponse(res, status.workerInfo)
              iter.remove()
            }
        }
      } else if (currentTime - status.startTime > timeout) {
        // Per-attempt timeout on a still-pending future.
        if (status.retriedTimes < maxRetries) {
          logError(
            s"Ask worker(${worker.readableAddress()}) CommitFiles for $shuffleId failed because of Timeout" +
              s" (attempt ${status.retriedTimes}/$maxRetries), will retry.")
          retryCommitFiles(status, currentTime)
        } else {
          // NOTE(review): after the last attempt times out, the entry is left in
          // the queue (no remove, no fail response here); it is finalized by the
          // trailing loop once remainingTime runs out — confirm intended.
          logError(
            s"Ask worker(${worker.readableAddress()}) CommitFiles for $shuffleId failed because of Timeout" +
              s" (attempt ${status.retriedTimes}/$maxRetries), will not retry.")
        }
      }
    }
    if (!futures.isEmpty) {
      Thread.sleep(delta)
    }
    remainingTime -= delta
  }

  // Budget exhausted: everything still pending is recorded as failed so no
  // partition id is silently dropped from the commit bookkeeping.
  val iter = futures.iterator()
  while (iter.hasNext) {
    val status = iter.next()
    logError(
      s"Ask worker(${status.workerInfo.readableAddress()}) CommitFiles for $shuffleId timed out")
    val res = createFailResponse(status)
    processResponse(res, status.workerInfo)
    iter.remove()
  }
}