override def getWriter[K, V]()

in src/main/scala/org/apache/spark/shuffle/RssShuffleManager.scala [149:312]


  /**
   * Creates the shuffle writer for a single map task.
   *
   * The method builds a multi-server write client (synchronous when the configured writer queue
   * size is 0, asynchronous otherwise), optionally wraps it in a [[LazyWriteClient]], connects it
   * and hands it to a new [[RssShuffleWriter]]. The client creation and connection are executed
   * through `RetryUtils.retry` with `pollInterval`, `pollInterval * 10` and the configured max
   * wait time as arguments.
   *
   * Only [[RssShuffleHandle]] is matched; since the `match` has no other case, any other handle
   * type results in a `scala.MatchError`.
   *
   * @param handle  shuffle handle; expected to be an [[RssShuffleHandle]]
   * @param mapId   id of the map task the writer is created for
   * @param context task context; supplies the task attempt id, stage id and task metrics
   * @tparam K shuffle key type
   * @tparam V shuffle value type
   * @return the shuffle writer backed by a connected write client
   * @throws RssException if the configured replica count is not positive
   */
  override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V] = {
    logInfo(s"getWriter: Use ShuffleManager: ${this.getClass().getSimpleName()}, $handle, mapId: $mapId, stageId: ${context.stageId()}, shuffleId: ${handle.shuffleId}")

    handle match {
      case rssShuffleHandle: RssShuffleHandle[K@unchecked, V@unchecked, _] =>
        val writerQueueSize = conf.get(RssOpts.writerQueueSize)

        val mapInfo = new AppTaskAttemptId(
          conf.getAppId,
          rssShuffleHandle.appAttempt,
          handle.shuffleId,
          mapId,
          context.taskAttemptId()
        )

        logDebug( s"getWriter $mapInfo" )

        createShuffleClientStageMetricsIfNeeded( rssShuffleHandle )

        val serializer = rssShuffleHandle.dependency.serializer
        val maxWaitMillis = conf.get( RssOpts.maxWaitTime )
        val useConnectionPool = conf.get(RssOpts.useConnectionPool)

        // Number of splits is derived from the number of map tasks and then bounded by the
        // configured minimum/maximum. The lower bound is checked first, exactly like before.
        val rssMapsPerSplit = conf.get(RssOpts.mapsPerSplit)
        val rssNumSplitsFromMaps = Math.ceil(rssShuffleHandle.numMaps.toDouble/rssMapsPerSplit.toDouble).toInt
        val rssMinSplits = conf.get(RssOpts.minSplits)
        val rssMaxSplits = conf.get(RssOpts.maxSplits)
        val rssNumSplits =
          if (rssNumSplitsFromMaps < rssMinSplits) {
            rssMinSplits
          } else if (rssNumSplitsFromMaps > rssMaxSplits) {
            rssMaxSplits
          } else {
            rssNumSplitsFromMaps
          }
        val shuffleWriteConfig = new ShuffleWriteConfig(rssNumSplits.toShort)

        val rssReplicas = conf.get(RssOpts.replicas)
        if (rssReplicas <= 0) {
          throw new RssException(s"Invalid config value for ${RssOpts.replicas.key}: $rssReplicas")
        }
        val rssServers: ServerList = ServerConnectionStringCache.getInstance().getServerList(rssShuffleHandle.getServerList)
        val serverReplicationGroups = ServerReplicationGroupUtil.createReplicationGroups(rssServers.getSevers, rssReplicas)

        // Resolves the current connection info of a server through the service registry and
        // fails the task with a FetchFailedException when the server is unknown to the shuffle
        // handle or differs from the one recorded in the shuffle handle.
        val serverConnectionResolver = new ServerConnectionStringResolver {
          override def resolveConnection(serverId: String): ServerDetail = {
            val serverDetailInShuffleHandle = rssShuffleHandle.getServerList.getSeverDetail(serverId)
            if (serverDetailInShuffleHandle == null) {
              throw new FetchFailedException(
                bmAddress = RssUtils.createMapTaskDummyBlockManagerId(mapInfo.getMapId, mapInfo.getTaskAttemptId),
                shuffleId = rssShuffleHandle.shuffleId,
                mapId = -1,
                reduceId = 0,
                message = s"Failed to get server detail for $serverId from shuffle handle: $rssShuffleHandle")
            }
            // random sleep some time to avoid request spike on service registry
            val random = new Random()
            val randomWaitMillis = random.nextInt(pollInterval)
            ThreadUtils.sleep(randomWaitMillis)

            val lookupResult = executeWithServiceRegistry(serviceRegistry =>
              serviceRegistry.lookupServers(dataCenter, cluster, util.Arrays.asList(serverId)))

            if (lookupResult == null) {
              throw new RssServerResolveException(s"Got null when looking up server for $serverId")
            }
            // exactly one server id was requested, so exactly one entry is expected back
            if (lookupResult.size() != 1) {
              throw new RssInvalidStateException(s"Invalid result $lookupResult when looking up server for $serverId")
            }
            val refreshedServer: ServerDetail = lookupResult.get(0)
            // add refreshed server into cache so future server lookup from the cache will get latest server.
            ServerConnectionStringCache.getInstance().updateServer(serverId, refreshedServer)
            if (!refreshedServer.equals(serverDetailInShuffleHandle)) {
              throw new FetchFailedException(
                bmAddress = RssUtils.createMapTaskDummyBlockManagerId(mapInfo.getMapId, mapInfo.getTaskAttemptId),
                shuffleId = rssShuffleHandle.shuffleId,
                mapId = -1,
                reduceId = 0,
                message = s"Detected server restart, current server: $refreshedServer, previous server: $serverDetailInShuffleHandle")
            }
            refreshedServer
          }
        }
        val serverConnectionRefresher = new ServerConnectionCacheUpdateRefresher(serverConnectionResolver, ServerConnectionStringCache.getInstance())

        val writerAsyncFinish = conf.get(RssOpts.writerAsyncFinish)
        val finishUploadAck = !writerAsyncFinish

        RetryUtils.retry(pollInterval, pollInterval * 10, maxWaitMillis, "create write client", new Supplier[ShuffleWriter[K, V]] {
          override def get(): ShuffleWriter[K, V] = {
            // A queue size of 0 selects the synchronous client, anything else the asynchronous one.
            val baseWriteClient: MultiServerWriteClient =
              if (writerQueueSize == 0) {
                logInfo(s"Use replicated sync writer, $rssNumSplits splits, ${rssShuffleHandle.partitionFanout} partition fanout, $serverReplicationGroups, finishUploadAck: $finishUploadAck")
                new MultiServerSyncWriteClient(
                  serverReplicationGroups,
                  rssShuffleHandle.partitionFanout,
                  networkTimeoutMillis,
                  maxWaitMillis,
                  serverConnectionRefresher,
                  finishUploadAck,
                  useConnectionPool,
                  rssShuffleHandle.user,
                  rssShuffleHandle.appId,
                  rssShuffleHandle.appAttempt,
                  shuffleWriteConfig)
              } else {
                val maxThreads = conf.get(RssOpts.writerMaxThreads)
                // one writer thread per 8 shuffle servers (rounded up), capped by maxThreads
                val serverThreadRatio = 8.0
                val numThreadsBasedOnShuffleServers = Math.ceil(rssShuffleHandle.rssServers.length.toDouble / serverThreadRatio)
                val numThreads = Math.min(numThreadsBasedOnShuffleServers, maxThreads).toInt
                logInfo(s"Use replicated async writer with queue size $writerQueueSize threads $numThreads, $rssNumSplits splits, ${rssShuffleHandle.partitionFanout} partition fanout, $serverReplicationGroups, finishUploadAck: $finishUploadAck")
                new MultiServerAsyncWriteClient(
                  serverReplicationGroups,
                  rssShuffleHandle.partitionFanout,
                  networkTimeoutMillis,
                  maxWaitMillis,
                  serverConnectionRefresher,
                  finishUploadAck,
                  useConnectionPool,
                  writerQueueSize,
                  numThreads,
                  rssShuffleHandle.user,
                  rssShuffleHandle.appId,
                  rssShuffleHandle.appAttempt,
                  shuffleWriteConfig)
              }

            val createLazyClientConnection: Boolean = conf.get(RssOpts.enableLazyMapperClientConnection)
            val writeClient: MultiServerWriteClient =
              if (createLazyClientConnection) {
                new LazyWriteClient(baseWriteClient, mapInfo,
                  rssShuffleHandle.numMaps, rssShuffleHandle.dependency.partitioner.numPartitions,
                  pollInterval, maxWaitMillis)
              } else {
                baseWriteClient
              }

            try {
              writeClient.connect()

              // Compare against the configured codec value. Comparing against the config entry
              // itself (RssOpts.compression) never matches, which made the configured zstd
              // compression level unreachable.
              val compressionCodec = conf.get(RssOpts.compression)
              val compressionLevel = if (Compression.COMPRESSION_CODEC_ZSTD.equals(compressionCodec)) {
                conf.get(RssOpts.zstdCompressionLevel)
              } else {
                0
              }

              new RssShuffleWriter(
                rssShuffleHandle.user,
                new ServerList(rssShuffleHandle.rssServers.map(_.toServerDetail()).toArray),
                writeClient,
                mapInfo,
                rssShuffleHandle.numMaps,
                serializer,
                compressionCodec,
                CompressionOptions(compressionLevel),
                bufferOptions,
                rssShuffleHandle.dependency,
                shuffleClientStageMetrics,
                context.taskMetrics(),
                conf)
            } catch {
              // Clean up the client on any failure and rethrow the original error unchanged.
              case ex: Throwable => {
                ExceptionUtils.closeWithoutException(writeClient)
                throw ex
              }
            }
          }
        })
    }
  }