override def registerShuffle[K, V, C]()

in src/main/scala/org/apache/spark/shuffle/RssShuffleManager.scala [78:146]


  /**
   * Registers a shuffle with the Remote Shuffle Service (RSS) and returns the
   * handle that tasks later use to reach the selected RSS servers.
   *
   * Flow: reject speculative execution (unsupported by RSS), select servers for
   * this shuffle, set up the heartbeat client and a one-time Spark listener,
   * record client-side stage metrics, and build the RssShuffleHandle.
   *
   * @param shuffleId  id of the shuffle being registered
   * @param numMaps    number of map tasks feeding this shuffle
   * @param dependency shuffle dependency (partitioner, serializer, combiner info)
   * @return a ShuffleHandle carrying the selected RSS server handles
   * @throws RssException if spark.speculation is enabled — RSS cannot handle the
   *                      random task attempt ids speculation produces
   */
  override def registerShuffle[K, V, C](shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
    // RSS does not support speculation yet, due to the random task attempt ids
    // (finished map task attempt id not always increasing), so fail fast here
    // rather than produce corrupt shuffle data.
    val useSpeculation = conf.getBoolean("spark.speculation", false)
    if (useSpeculation) {
      throw new RssException("Do not support speculation in Remote Shuffle Service")
    }

    logInfo(s"Use ShuffleManager: ${this.getClass().getSimpleName()}")

    val numPartitions = dependency.partitioner.numPartitions

    val sparkContext = getSparkContext

    val user = sparkContext.sparkUser
    val queue = conf.get(SparkYarnQueueConfigKey, "")

    val appId = conf.getAppId
    val appAttempt = sparkContext.applicationAttemptId.getOrElse("0")

    // Heartbeats keep the RSS servers aware of this application; the connection
    // refresher is installed at most once on the shared singleton client.
    val heartbeatClient = MultiServerHeartbeatClient.getInstance()
    heartbeatClient.setAppContext(user, appId, appAttempt)
    if (!heartbeatClient.hasServerConnectionRefresher) {
      heartbeatClient.setServerConnectionRefresher(createServerConnectionRefresher4Heartbeat())
    }

    // Hosts listed in the exclude config are never chosen as shuffle servers.
    val excludeHostsConfigValue = conf.get(RssOpts.excludeHosts)
    val excludeHosts = excludeHostsConfigValue.split(",").filter(_.nonEmpty).distinct

    // Assigned exactly once, so a val (was a var initialized to null).
    val rssServerSelectionResult = getRssServers(numMaps, numPartitions, excludeHosts)
    val rssServers = rssServerSelectionResult.servers
    logInfo(s"Selected ${rssServers.size} RSS servers for shuffle $shuffleId, maps: $numMaps, partitions: $numPartitions, replicas: ${rssServerSelectionResult.replicas}, partition fanout: ${rssServerSelectionResult.partitionFanout}, ${rssServers.mkString(",")}")

    // Gauge how many servers were selected, tagged by data center/cluster/user.
    val tagMap = new java.util.HashMap[String, String]()
    tagMap.put(RssDataCenterTagName, dataCenter)
    tagMap.put(RssClusterTagName, cluster)
    tagMap.put(UserMetricTagName, user)
    M3Stats.getDefaultScope.tagged(tagMap).gauge(NumRssServersMetricName).update(rssServers.length)

    // Listener is registered only once per SparkContext; the factory closure is
    // evaluated lazily by registerSparkListenerOnlyOnce.
    RssSparkListener.registerSparkListenerOnlyOnce(sparkContext, () =>
      new RssSparkListener(
        user,
        appId,
        appAttempt,
        rssServerSelectionResult.servers.map(_.getConnectionString()),
        networkTimeoutMillis))

    val shuffleClientStageMetricsKey = new ShuffleClientStageMetricsKey(user, queue)
    shuffleClientStageMetrics = new ShuffleClientStageMetrics(shuffleClientStageMetricsKey)

    shuffleClientStageMetrics.getNumRegisterShuffle.inc(1)
    shuffleClientStageMetrics.getNumMappers().recordValue(numMaps)
    shuffleClientStageMetrics.getNumReducers().recordValue(numPartitions)

    // Human-readable summary of the dependency for the registration log line.
    val dependencyInfo = s"numPartitions: ${dependency.partitioner.numPartitions}, " +
      s"serializer: ${dependency.serializer.getClass().getSimpleName()}, " +
      s"keyOrdering: ${dependency.keyOrdering}, " +
      s"aggregator: ${dependency.aggregator}, " +
      s"mapSideCombine: ${dependency.mapSideCombine}, " +
      s"keyClassName: ${dependency.keyClassName}, " +
      s"valueClassName: ${dependency.valueClassName}"

    logInfo(s"registerShuffle: $appId, $appAttempt, $shuffleId, $numMaps, $dependencyInfo")

    val rssServerHandles = rssServerSelectionResult.servers.map(t => new RssShuffleServerHandle(t.getServerId(), t.getRunningVersion(), t.getConnectionString())).toArray
    new RssShuffleHandle(shuffleId, appId, appAttempt, numMaps, user, queue, dependency, rssServerHandles, rssServerSelectionResult.partitionFanout)
  }