def run()

in src/main/scala/org/apache/spark/shuffle/rss/RssStressTool.scala [171:259]

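Runs a single end-to-end stress iteration: it generates a pool of test values (seeded with null and empty-string edge cases), writes them through concurrent simulated map tasks with randomized retry counts, then reads every partition back and verifies both that each record comes from the test pool and that the total record count matches the expected value.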

  def run(): Unit = {
    logInfo(String.format("Server root dirs: %s", StringUtils.join(serverRootDirs, ':')))

    // Generate test values to use
    val testValues = new util.ArrayList[String]
    testValues.add(null)
    testValues.add("")
    while (testValues.size < numTestValues) {
      val ch = ('a' + random.nextInt(26)).toChar
      val repeats = random.nextInt(maxTestValueLen)
      val str = StringUtils.repeat(ch, repeats)
      testValues.add(str)
    }

    // Create map task threads to write shuffle data
    val simulatedNumberOfAttemptsForMappers = new util.ArrayList[Integer]
    val fetchTaskAttemptIds = new util.ArrayList[Long]
    var i = startMapId
    while (i <= endMapId) {
      val value = random.nextInt(numMapTaskRetries) + 1
      simulatedNumberOfAttemptsForMappers.add(value)
      i += 1
    }
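    // Spawn one thread per map id. Each thread runs its simulated attempts
    // sequentially, and only the final attempt's taskAttemptId is recorded in
    // fetchTaskAttemptIds: the read side should count data from last attempts
    // only, so earlier (potentially zombie) attempts must not be included.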
    var mapId = startMapId
    while (mapId <= endMapId) {
      val index = mapId - startMapId
      val appMapId = new AppMapId(appShuffleId.getAppId, appShuffleId.getAppAttempt, appShuffleId.getShuffleId, mapId)
      val simulatedNumberOfAttempts = simulatedNumberOfAttemptsForMappers.get(index)
      val thread = new Thread(new Runnable {
        override def run(): Unit = {
          var attempt = 1
          while (attempt <= simulatedNumberOfAttempts) {
            val taskAttemptId = taskAttemptIdSeed.getAndIncrement
            val isLastTaskAttempt = attempt == simulatedNumberOfAttempts
            if (isLastTaskAttempt) {
              // ArrayList is not thread-safe, so synchronize concurrent adds
              // from the mapper threads on the list itself.
              fetchTaskAttemptIds.synchronized {
                fetchTaskAttemptIds.add(taskAttemptId)
              }
            }
            simulateMapperTask(testValues, appMapId, taskAttemptId, isLastTaskAttempt)
            attempt += 1
          }
        }
      })
      thread.setName(String.format("[Map Thread %s]", appMapId))
      thread.setUncaughtExceptionHandler(new Thread.UncaughtExceptionHandler() {
        override def uncaughtException(t: Thread, e: Throwable): Unit = {
          logError(String.format("Mapper thread %s got exception", t.getName), e)
          mapThreadErrors.incrementAndGet
        }
      })
      allMapThreads.add(thread)

      mapId += 1
    }

    // Start map task threads
    allMapThreads.asScala.foreach((t: Thread) => t.start())
    // Wait for map tasks to finish
    allMapThreads.asScala.foreach((t: Thread) => t.join())

    if (mapThreadErrors.get > 0) {
      throw new RuntimeException("Number of errors in map threads: " + mapThreadErrors)
    }

    // Read shuffle data
    val allReadData = new ListBuffer[Product2[String, String]]()
    (0 until numPartitions).foreach(p => {
      allReadData.appendAll(readShuffleData(p))
    })

    logInfo(s"Total read records: ${allReadData.size}")

    // Every key and value read back must come from the generated test pool;
    // anything else indicates corrupted or foreign shuffle data.
    allReadData.foreach(t => {
      if (!testValues.contains(t._1)) {
        throw new RuntimeException(s"Detected failure, read unexpected record key: ${t._1}")
      }
      if (!testValues.contains(t._2)) {
        throw new RuntimeException(s"Detected failure, read unexpected record value: ${t._2}")
      }
    })

    // The count check assumes each map task contributes testValues.size
    // records, one per test value, written by its final attempt.
    val expectedNumRecords = testValues.size() * numMaps
    if (allReadData.size != expectedNumRecords) {
      throw new RuntimeException(s"Detected failure, expected records: $expectedNumRecords, actual records: ${allReadData.size}")
    }

    logInfo("Test run finished successfully")
  }
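
In this excerpt fetchTaskAttemptIds is only populated, never consumed; the write and read paths (simulateMapperTask, readShuffleData) live outside it, so the list's purpose is easy to miss. Below is a minimal sketch of the idea, with hypothetical names (RawRecord, filterToLastAttempts) that do not appear in the repo: a reader that can see records from every attempt must keep only those written by each mapper's final attempt, otherwise stale attempts would inflate the count compared against expectedNumRecords.

  // Hypothetical sketch, not the tool's actual reader.
  object LastAttemptFilterSketch {
    final case class RawRecord(taskAttemptId: Long, key: String, value: String)

    // Keep only records whose writer was the final attempt of its map task;
    // earlier ("zombie") attempts may have written duplicate or partial data.
    def filterToLastAttempts(
        records: Seq[RawRecord],
        lastAttemptIds: Set[Long]): Seq[(String, String)] =
      records.collect {
        case RawRecord(id, k, v) if lastAttemptIds.contains(id) => (k, v)
      }
  }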