in src/main/scala/org/apache/spark/shuffle/RssShuffleManager.scala [149:312]
override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V] = {
  /*
   * Creates a shuffle writer for one map task attempt that writes to the remote shuffle
   * servers recorded in the given RssShuffleHandle.
   *
   * The split count is ceil(numMaps / mapsPerSplit), clamped into [minSplits, maxSplits].
   * A synchronous write client is used when writerQueueSize == 0, otherwise an asynchronous
   * one; the client is optionally wrapped in a LazyWriteClient. Creating and connecting the
   * client is retried through RetryUtils.retry for up to maxWaitTime.
   *
   * @param handle  shuffle handle; only RssShuffleHandle is supported
   * @param mapId   map id of this task
   * @param context task context (stage id, task attempt id, task metrics)
   * @return a connected RssShuffleWriter
   * @throws RssException if the replica count is not positive, the split count does not fit
   *                      in a Short, or handle is not an RssShuffleHandle
   * Note: the connection resolver handed to the write client throws FetchFailedException
   * when a server id is missing from the handle or its registry entry has changed.
   */
  logInfo(s"getWriter: Use ShuffleManager: ${this.getClass().getSimpleName()}, $handle, mapId: $mapId, stageId: ${context.stageId()}, shuffleId: ${handle.shuffleId}")
  handle match {
    case rssShuffleHandle: RssShuffleHandle[K@unchecked, V@unchecked, _] => {
      val writerQueueSize = conf.get(RssOpts.writerQueueSize)
      val mapInfo = new AppTaskAttemptId(
        conf.getAppId,
        rssShuffleHandle.appAttempt,
        handle.shuffleId,
        mapId,
        context.taskAttemptId()
      )
      logDebug(s"getWriter $mapInfo")
      createShuffleClientStageMetricsIfNeeded(rssShuffleHandle)
      val serializer = rssShuffleHandle.dependency.serializer
      val maxWaitMillis = conf.get(RssOpts.maxWaitTime)
      val useConnectionPool = conf.get(RssOpts.useConnectionPool)

      // One split per `mapsPerSplit` maps, clamped into [minSplits, maxSplits].
      val rssMapsPerSplit = conf.get(RssOpts.mapsPerSplit)
      val rssMinSplits = conf.get(RssOpts.minSplits)
      val rssMaxSplits = conf.get(RssOpts.maxSplits)
      val rawNumSplits = Math.ceil(rssShuffleHandle.numMaps.toDouble / rssMapsPerSplit.toDouble).toInt
      val rssNumSplits: Int =
        if (rawNumSplits < rssMinSplits) rssMinSplits
        else if (rawNumSplits > rssMaxSplits) rssMaxSplits
        else rawNumSplits
      // ShuffleWriteConfig takes the split count as a Short; reject values that would wrap around.
      if (rssNumSplits > Short.MaxValue) {
        throw new RssException(s"Invalid number of splits $rssNumSplits (must not exceed ${Short.MaxValue}), check ${RssOpts.minSplits.key} and ${RssOpts.maxSplits.key}")
      }
      val shuffleWriteConfig = new ShuffleWriteConfig(rssNumSplits.toShort)

      val rssReplicas = conf.get(RssOpts.replicas)
      if (rssReplicas <= 0) {
        throw new RssException(s"Invalid config value for ${RssOpts.replicas.key}: $rssReplicas")
      }
      val rssServers: ServerList = ServerConnectionStringCache.getInstance().getServerList(rssShuffleHandle.getServerList)
      val serverReplicationGroups = ServerReplicationGroupUtil.createReplicationGroups(rssServers.getSevers, rssReplicas)

      // Re-resolves a server id through the service registry and refreshes the shared cache.
      // A server unknown to the shuffle handle, or whose registry entry differs from the one
      // in the handle (treated as a server restart), is reported as a FetchFailedException.
      val serverConnectionResolver = new ServerConnectionStringResolver {
        override def resolveConnection(serverId: String): ServerDetail = {
          val serverDetailInShuffleHandle = rssShuffleHandle.getServerList.getSeverDetail(serverId)
          if (serverDetailInShuffleHandle == null) {
            throw new FetchFailedException(
              bmAddress = RssUtils.createMapTaskDummyBlockManagerId(mapInfo.getMapId, mapInfo.getTaskAttemptId),
              shuffleId = rssShuffleHandle.shuffleId,
              mapId = -1,
              reduceId = 0,
              message = s"Failed to get server detail for $serverId from shuffle handle: $rssShuffleHandle")
          }
          // random sleep some time to avoid request spike on service registry
          val random = new Random()
          val randomWaitMillis = random.nextInt(pollInterval)
          ThreadUtils.sleep(randomWaitMillis)
          val lookupResult = executeWithServiceRegistry(serviceRegistry =>
            serviceRegistry.lookupServers(dataCenter, cluster, util.Arrays.asList(serverId)))
          if (lookupResult == null) {
            throw new RssServerResolveException(s"Got null when looking up server for $serverId")
          }
          // Exactly one server is expected for a single requested id.
          if (lookupResult.size() != 1) {
            throw new RssInvalidStateException(s"Invalid result $lookupResult when looking up server for $serverId")
          }
          val refreshedServer: ServerDetail = lookupResult.get(0)
          // add refreshed server into cache so future server lookup from the cache will get latest server.
          ServerConnectionStringCache.getInstance().updateServer(serverId, refreshedServer)
          if (!refreshedServer.equals(serverDetailInShuffleHandle)) {
            throw new FetchFailedException(
              bmAddress = RssUtils.createMapTaskDummyBlockManagerId(mapInfo.getMapId, mapInfo.getTaskAttemptId),
              shuffleId = rssShuffleHandle.shuffleId,
              mapId = -1,
              reduceId = 0,
              message = s"Detected server restart, current server: $refreshedServer, previous server: $serverDetailInShuffleHandle")
          }
          refreshedServer
        }
      }
      val serverConnectionRefresher = new ServerConnectionCacheUpdateRefresher(serverConnectionResolver, ServerConnectionStringCache.getInstance())
      val writerAsyncFinish = conf.get(RssOpts.writerAsyncFinish)
      val finishUploadAck = !writerAsyncFinish

      // Settings that do not change between retry attempts are read once, up front.
      val createLazyClientConnection: Boolean = conf.get(RssOpts.enableLazyMapperClientConnection)
      val compressionCodec = conf.get(RssOpts.compression)
      // Compare against the configured codec *value*; comparing against the config entry
      // object itself never matches, so the zstd level would never be applied.
      val compressionLevel = if (Compression.COMPRESSION_CODEC_ZSTD.equals(compressionCodec)) {
        conf.get(RssOpts.zstdCompressionLevel)
      } else {
        0
      }

      RetryUtils.retry(pollInterval, pollInterval * 10, maxWaitMillis, "create write client", new Supplier[ShuffleWriter[K, V]] {
        override def get(): ShuffleWriter[K, V] = {
          val baseWriteClient: MultiServerWriteClient =
            if (writerQueueSize == 0) {
              logInfo(s"Use replicated sync writer, $rssNumSplits splits, ${rssShuffleHandle.partitionFanout} partition fanout, $serverReplicationGroups, finishUploadAck: $finishUploadAck")
              new MultiServerSyncWriteClient(
                serverReplicationGroups,
                rssShuffleHandle.partitionFanout,
                networkTimeoutMillis,
                maxWaitMillis,
                serverConnectionRefresher,
                finishUploadAck,
                useConnectionPool,
                rssShuffleHandle.user,
                rssShuffleHandle.appId,
                rssShuffleHandle.appAttempt,
                shuffleWriteConfig)
            } else {
              val maxThreads = conf.get(RssOpts.writerMaxThreads)
              // Roughly one writer thread per 8 shuffle servers, capped at writerMaxThreads.
              val serverThreadRatio = 8.0
              val numThreadsBasedOnShuffleServers = Math.ceil(rssShuffleHandle.rssServers.length.toDouble / serverThreadRatio)
              val numThreads = Math.min(numThreadsBasedOnShuffleServers, maxThreads).toInt
              logInfo(s"Use replicated async writer with queue size $writerQueueSize threads $numThreads, $rssNumSplits splits, ${rssShuffleHandle.partitionFanout} partition fanout, $serverReplicationGroups, finishUploadAck: $finishUploadAck")
              new MultiServerAsyncWriteClient(
                serverReplicationGroups,
                rssShuffleHandle.partitionFanout,
                networkTimeoutMillis,
                maxWaitMillis,
                serverConnectionRefresher,
                finishUploadAck,
                useConnectionPool,
                writerQueueSize,
                numThreads,
                rssShuffleHandle.user,
                rssShuffleHandle.appId,
                rssShuffleHandle.appAttempt,
                shuffleWriteConfig)
            }
          // The outermost client created so far; this is what gets closed if setup fails.
          var writeClient: MultiServerWriteClient = baseWriteClient
          try {
            if (createLazyClientConnection) {
              writeClient = new LazyWriteClient(baseWriteClient, mapInfo,
                rssShuffleHandle.numMaps, rssShuffleHandle.dependency.partitioner.numPartitions,
                pollInterval, maxWaitMillis)
            }
            writeClient.connect()
            new RssShuffleWriter(
              rssShuffleHandle.user,
              new ServerList(rssShuffleHandle.rssServers.map(_.toServerDetail()).toArray),
              writeClient,
              mapInfo,
              rssShuffleHandle.numMaps,
              serializer,
              compressionCodec,
              CompressionOptions(compressionLevel),
              bufferOptions,
              rssShuffleHandle.dependency,
              shuffleClientStageMetrics,
              context.taskMetrics(),
              conf)
          } catch {
            // Cleanup-and-rethrow: release the client on any failure, then propagate the
            // original error unchanged (RetryUtils.retry decides whether to try again).
            case ex: Throwable => {
              ExceptionUtils.closeWithoutException(writeClient)
              throw ex
            }
          }
        }
      })
    }
    case other =>
      throw new RssException(s"Unsupported shuffle handle (expected RssShuffleHandle): $other")
  }
}