in client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala [532:768]
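// Handles a RegisterShuffle request end to end: deduplicates concurrent register requests for
// the same shuffle, asks the Master for slots, sets up worker endpoints, reserves slots on the
// chosen workers, records the allocation, and finally replies to every pending caller.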
private def offerAndReserveSlots(
context: RegisterCallContext,
shuffleId: Int,
numMappers: Int,
numPartitions: Int,
partitionId: Int = -1,
isSegmentGranularityVisible: Boolean = false): Unit = {
val partitionType = getPartitionType(shuffleId)
registeringShuffleRequest.synchronized {
if (registeringShuffleRequest.containsKey(shuffleId)) {
// If a request for the same shuffle already exists in the registering request list,
// just record this context and return; it will be replied to once that request completes.
logDebug(s"[handleRegisterShuffle] request for shuffle $shuffleId exists, just register")
registeringShuffleRequest.get(shuffleId).add(context)
return
} else {
// If the shuffle is already registered, reply with its latest partition locations and return.
// Otherwise add this request to registeringShuffleRequest.
if (registeredShuffle.contains(shuffleId)) {
val rpcContext: RpcCallContext = context.context
partitionType match {
case PartitionType.MAP =>
processMapTaskReply(
shuffleId,
rpcContext,
partitionId,
getLatestLocs(shuffleId, p => p.getId == partitionId))
case PartitionType.REDUCE =>
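// Local callers get the response object directly; remote callers get a cached serialized
// response so repeated registrations do not re-serialize the same partition locations.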
if (rpcContext.isInstanceOf[LocalNettyRpcCallContext]) {
context.reply(RegisterShuffleResponse(
StatusCode.SUCCESS,
getLatestLocs(shuffleId, _ => true)))
} else {
val cachedMsg = registerShuffleResponseRpcCache.get(
shuffleId,
new Callable[ByteBuffer]() {
override def call(): ByteBuffer = {
rpcContext.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(
RegisterShuffleResponse(
StatusCode.SUCCESS,
getLatestLocs(shuffleId, _ => true)))
}
})
rpcContext.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg)
}
case _ =>
throw new UnsupportedOperationException(s"Not support $partitionType yet")
}
return
}
logInfo(s"New shuffle request, shuffleId $shuffleId, partitionType: $partitionType " +
s"numMappers: $numMappers, numReducers: $numPartitions.")
val set = new util.HashSet[RegisterCallContext]()
set.add(context)
registeringShuffleRequest.put(shuffleId, set)
}
}
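// Collects, for every partition id of this shuffle, the primary PartitionLocation with the
// highest epoch across all workers, then applies the given filter.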
def getLatestLocs(
shuffleId: Int,
partitionLocationFilter: PartitionLocation => Boolean): Array[PartitionLocation] = {
workerSnapshots(shuffleId)
.values()
.asScala
.flatMap(
_.getAllPrimaryLocationsWithMaxEpoch()
) // the latest-epoch primary locations held by each worker
.foldLeft(Map.empty[Int, PartitionLocation]) { (partitionLocationMap, partitionLocation) =>
partitionLocationMap.get(partitionLocation.getId) match {
case Some(existing) if existing.getEpoch >= partitionLocation.getEpoch =>
partitionLocationMap
case _ => partitionLocationMap + (partitionLocation.getId -> partitionLocation)
}
} // keep only the highest-epoch location for each partition id across all workers
.values
.filter(partitionLocationFilter)
.toArray
}
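// Illustrative only, not part of the original LifecycleManager code: a minimal sketch of the
// "latest epoch wins" reduction used by getLatestLocs above, written against plain
// (partitionId, epoch) pairs. The helper name `latestEpochById` is hypothetical; for example,
// latestEpochById(Seq((3, 0), (3, 2), (5, 1))) returns Map(3 -> 2, 5 -> 1).
def latestEpochById(locs: Seq[(Int, Int)]): Map[Int, Int] =
locs.foldLeft(Map.empty[Int, Int]) { case (acc, (id, epoch)) =>
acc.get(id) match {
// keep the existing entry when its epoch is not older than the incoming one
case Some(existing) if existing >= epoch => acc
case _ => acc + (id -> epoch)
}
}
// For MAP partition shuffles: reply immediately when a location already exists for the task's
// partition, otherwise ask changePartitionManager to allocate a new location for it.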
def processMapTaskReply(
shuffleId: Int,
context: RpcCallContext,
partitionId: Int,
partitionLocations: Array[PartitionLocation]): Unit = {
// If any partition location already exists, just reply with it.
if (partitionLocations.size > 0) {
context.reply(RegisterShuffleResponse(StatusCode.SUCCESS, partitionLocations))
} else {
// Otherwise request a new partition location for this task.
changePartitionManager.handleRequestPartitionLocation(
ApplyNewLocationCallContext(context),
shuffleId,
partitionId,
-1,
null,
isSegmentGranularityVisible = commitManager.isSegmentGranularityVisible(shuffleId))
}
}
// Reply to all pending RegisterShuffle requests for the current shuffle id.
def replyRegisterShuffle(response: PbRegisterShuffleResponse): Unit = {
registeringShuffleRequest.synchronized {
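// For REDUCE partition shuffles, serialize a successful response once up front so the same byte
// buffer can be cached and replayed to every pending remote caller below.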
val serializedMsg: Option[ByteBuffer] = partitionType match {
case PartitionType.REDUCE =>
context.context match {
case remoteContext: RemoteNettyRpcCallContext =>
if (response.getStatus == StatusCode.SUCCESS.getValue) {
Option(remoteContext.nettyEnv.serialize(
response))
} else {
Option.empty
}
case _ => Option.empty
}
case _ => Option.empty
}
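// Unpack the partition locations carried in the packed protobuf response.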
val locations = PbSerDeUtils.fromPbPackedPartitionLocationsPair(
response.getPackedPartitionLocationsPair)._1.asScala
registeringShuffleRequest.asScala
.get(shuffleId)
.foreach(_.asScala.foreach(context => {
partitionType match {
case PartitionType.MAP =>
if (response.getStatus == StatusCode.SUCCESS.getValue) {
val partitionLocations = locations.filter(_.getId == context.partitionId).toArray
processMapTaskReply(
shuffleId,
context.context,
context.partitionId,
partitionLocations)
} else {
// When registration was not successful, reply with the original response;
// otherwise the original exception message would be lost.
context.reply(response)
}
case PartitionType.REDUCE =>
if (context.context.isInstanceOf[
LocalNettyRpcCallContext] || response.getStatus != StatusCode.SUCCESS.getValue) {
context.reply(response)
} else {
registerShuffleResponseRpcCache.put(shuffleId, serializedMsg.get)
context.context.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(
serializedMsg.get)
}
case _ =>
throw new UnsupportedOperationException(s"Not support $partitionType yet")
}
}))
registeringShuffleRequest.remove(shuffleId)
}
}
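// Illustrative only: registerShuffleResponseRpcCache is used above as a Guava-style cache keyed
// by shuffleId, so the serialized RegisterShuffleResponse is built at most once and replayed to
// later remote callers. A minimal sketch of that pattern, assuming Guava's Cache API; the names
// responseCache and serializeResponse are hypothetical:
//   val responseCache: Cache[Integer, ByteBuffer] =
//     CacheBuilder.newBuilder().maximumSize(1024).build[Integer, ByteBuffer]()
//   val bytes = responseCache.get(shuffleId, new Callable[ByteBuffer] {
//     override def call(): ByteBuffer = serializeResponse() // runs only on a cache miss
//   })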
// First, request allocated slots from the Master.
val ids = new util.ArrayList[Integer](numPartitions)
(0 until numPartitions).foreach(idx => ids.add(Integer.valueOf(idx)))
val res = requestMasterRequestSlotsWithRetry(shuffleId, ids)
res.status match {
case StatusCode.REQUEST_FAILED =>
logInfo(s"OfferSlots RPC request failed for $shuffleId!")
replyRegisterShuffle(RegisterShuffleResponse(StatusCode.REQUEST_FAILED, Array.empty))
return
case StatusCode.SLOT_NOT_AVAILABLE =>
logInfo(s"OfferSlots for $shuffleId failed!")
replyRegisterShuffle(RegisterShuffleResponse(StatusCode.SLOT_NOT_AVAILABLE, Array.empty))
return
case StatusCode.SUCCESS =>
logDebug(s"OfferSlots for $shuffleId Success!Slots Info: ${res.workerResource}")
case StatusCode.WORKER_EXCLUDED =>
logInfo(s"OfferSlots for $shuffleId failed due to all workers be excluded!")
replyRegisterShuffle(RegisterShuffleResponse(StatusCode.WORKER_EXCLUDED, Array.empty))
return
case _ => // won't happen
throw new UnsupportedOperationException()
}
// Reserve slots for each PartitionLocation. When the response status is SUCCESS, WorkerResource
// won't be empty, since the Master replies SLOT_NOT_AVAILABLE when no slots could be allocated.
val slots = res.workerResource
val candidatesWorkers = new util.HashSet(slots.keySet())
val connectFailedWorkers = new ShuffleFailedWorkers()
// Second, for each worker, try to initialize the endpoint.
setupEndpoints(slots.keySet(), shuffleId, connectFailedWorkers)
candidatesWorkers.removeAll(connectFailedWorkers.asScala.keys.toList.asJava)
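// Record the workers we could not connect to so workerStatusTracker can exclude them later.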
workerStatusTracker.recordWorkerFailure(connectFailedWorkers)
// If a worker was newly allocated by the Master and its endpoint was set up successfully,
// LifecycleManager should remove it from the excluded worker list to improve the list's accuracy.
workerStatusTracker.removeFromExcludedWorkers(candidatesWorkers)
// Third, for each slot, LifecycleManager should ask the Worker to reserve the slot
// and prepare the environment for pushing data.
val reserveSlotsSuccess =
reserveSlotsWithRetry(
shuffleId,
candidatesWorkers,
slots,
updateEpoch = false,
isSegmentGranularityVisible)
// If reserving slots failed, clear the allocated resources, reply RESERVE_SLOTS_FAILED and return.
if (!reserveSlotsSuccess) {
logError(s"reserve buffer for $shuffleId failed, reply to all.")
replyRegisterShuffle(RegisterShuffleResponse(StatusCode.RESERVE_SLOTS_FAILED, Array.empty))
} else {
if (log.isDebugEnabled()) {
logDebug(s"ReserveSlots for $shuffleId success with details:$slots!")
}
// Fourth, the shuffle is registered successfully; update its status.
val allocatedWorkers =
JavaUtils.newConcurrentHashMap[String, ShufflePartitionLocationInfo]()
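// Group the allocated primary/replica locations by worker and remember the latest primary
// location of each partition.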
slots.asScala.foreach { case (workerInfo, (primaryLocations, replicaLocations)) =>
val partitionLocationInfo = new ShufflePartitionLocationInfo(workerInfo)
partitionLocationInfo.addPrimaryPartitions(primaryLocations)
updateLatestPartitionLocations(shuffleId, primaryLocations)
partitionLocationInfo.addReplicaPartitions(replicaLocations)
allocatedWorkers.put(workerInfo.toUniqueId, partitionLocationInfo)
}
shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers)
registeredShuffle.add(shuffleId)
commitManager.registerShuffle(
shuffleId,
numMappers,
isSegmentGranularityVisible)
// Fifth, reply with the allocated partition locations to the ShuffleClient.
logInfo(s"Handle RegisterShuffle Success for $shuffleId.")
val allPrimaryPartitionLocations = slots.asScala.flatMap(_._2._1.asScala).toArray
replyRegisterShuffle(RegisterShuffleResponse(
StatusCode.SUCCESS,
allPrimaryPartitionLocations))
}
}