in src/main/scala/org/apache/spark/shuffle/RssShuffleManager.scala [78:146]
override def registerShuffle[K, V, C](shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
// RSS does not support speculation yet: task attempt ids are random, so the attempt id of a finished map task is not always increasing.
// Fail fast here if speculation is enabled.
val useSpeculation = conf.getBoolean("spark.speculation", false)
if (useSpeculation) {
throw new RssException("Speculation is not supported by Remote Shuffle Service")
}
logInfo(s"Use ShuffleManager: ${this.getClass().getSimpleName()}")
val numPartitions = dependency.partitioner.numPartitions
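// Gather application identity (user, YARN queue, application id and attempt) used below for heartbeats, metrics, and the shuffle handle.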
val sparkContext = getSparkContext
val user = sparkContext.sparkUser
val queue = conf.get(SparkYarnQueueConfigKey, "")
val appId = conf.getAppId
val appAttempt = sparkContext.applicationAttemptId.getOrElse("0")
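// MultiServerHeartbeatClient is a process-wide singleton: set the application context, and install a server connection refresher only if one is not already present.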
val heartbeatClient = MultiServerHeartbeatClient.getInstance()
heartbeatClient.setAppContext(user, appId, appAttempt)
if (!heartbeatClient.hasServerConnectionRefresher) {
heartbeatClient.setServerConnectionRefresher(createServerConnectionRefresher4Heartbeat())
}
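// Select RSS servers for this shuffle, skipping any hosts excluded via RssOpts.excludeHosts (comma-separated, blank entries ignored).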
val excludeHostsConfigValue = conf.get(RssOpts.excludeHosts)
val excludeHosts = excludeHostsConfigValue.split(",").filter(_.nonEmpty).distinct
val rssServerSelectionResult = getRssServers(numMaps, numPartitions, excludeHosts)
val rssServers = rssServerSelectionResult.servers
logInfo(s"Selected ${rssServers.size} RSS servers for shuffle $shuffleId, maps: $numMaps, partitions: $numPartitions, replicas: ${rssServerSelectionResult.replicas}, partition fanout: ${rssServerSelectionResult.partitionFanout}, ${rssServers.mkString(",")}")
val tagMap = new java.util.HashMap[String, String]()
tagMap.put(RssDataCenterTagName, dataCenter)
tagMap.put(RssClusterTagName, cluster)
tagMap.put(UserMetricTagName, user)
M3Stats.getDefaultScope.tagged(tagMap).gauge(NumRssServersMetricName).update(rssServers.length)
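// Register the RSS Spark listener at most once per application; it is constructed lazily with the selected servers' connection strings.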
RssSparkListener.registerSparkListenerOnlyOnce(sparkContext, () =>
new RssSparkListener(
user,
conf.getAppId,
appAttempt,
rssServerSelectionResult.servers.map(_.getConnectionString()),
networkTimeoutMillis))
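// Track per-stage shuffle client metrics, keyed by user and queue.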
val shuffleClientStageMetricsKey = new ShuffleClientStageMetricsKey(user, queue)
shuffleClientStageMetrics = new ShuffleClientStageMetrics(shuffleClientStageMetricsKey)
shuffleClientStageMetrics.getNumRegisterShuffle.inc(1)
shuffleClientStageMetrics.getNumMappers.recordValue(numMaps)
shuffleClientStageMetrics.getNumReducers.recordValue(numPartitions)
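// Summarize the shuffle dependency for the registration log line below.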
val dependencyInfo = s"numPartitions: ${dependency.partitioner.numPartitions}, " +
s"serializer: ${dependency.serializer.getClass().getSimpleName()}, " +
s"keyOrdering: ${dependency.keyOrdering}, " +
s"aggregator: ${dependency.aggregator}, " +
s"mapSideCombine: ${dependency.mapSideCombine}, " +
s"keyClassName: ${dependency.keyClassName}, " +
s"valueClassName: ${dependency.valueClassName}"
logInfo(s"registerShuffle: $appId, $appAttempt, $shuffleId, $numMaps, $dependencyInfo")
val rssServerHandles = rssServerSelectionResult.servers.map(t => new RssShuffleServerHandle(t.getServerId(), t.getRunningVersion(), t.getConnectionString())).toArray
new RssShuffleHandle(shuffleId, appId, appAttempt, numMaps, user, queue, dependency, rssServerHandles, rssServerSelectionResult.partitionFanout)
}