in src/main/scala/org/apache/spark/shuffle/rss/RssStressTool.scala [171:259]
/**
 * Runs one end-to-end stress iteration against the shuffle servers.
 *
 * Steps:
 *  1. Builds the pool of test strings: `null`, `""`, then random strings consisting of a single
 *     repeated lowercase letter (length in `[0, maxTestValueLen)`) until the pool holds
 *     `numTestValues` entries (it always holds at least the two fixed entries).
 *  2. Creates one writer thread per map id in `[startMapId, endMapId]`. Each thread simulates
 *     between 1 and `numMapTaskRetries` attempts, calling `simulateMapperTask` once per attempt
 *     with a fresh id from `taskAttemptIdSeed`; only the final attempt is flagged as the last one.
 *  3. Starts all threads, joins them, and fails if any thread reported an uncaught exception.
 *  4. Reads every partition in `[0, numPartitions)` and verifies that each key and value read
 *     back is one of the generated test strings and that the total record count equals
 *     `testValues.size * numMaps`.
 *
 * @throws RuntimeException if a mapper thread failed, an unexpected key/value was read, or the
 *                          number of records read does not match the expected count
 */
def run(): Unit = {
  logInfo(String.format("Server root dirs: %s", StringUtils.join(serverRootDirs, ':')))

  // Test value pool; null and "" are always included as edge cases.
  val testValues = new util.ArrayList[String]
  testValues.add(null)
  testValues.add("")
  while (testValues.size < numTestValues) {
    val ch = ('a' + random.nextInt(26)).toChar
    val repeats = random.nextInt(maxTestValueLen)
    testValues.add(StringUtils.repeat(ch, repeats))
  }

  // Attempt counts are drawn here, on the calling thread, in map-id order and before any map
  // thread is started. Range.map is strict, so all draws happen right now.
  val attemptsPerMapper = (startMapId to endMapId).map { mapId =>
    mapId -> (random.nextInt(numMapTaskRetries) + 1)
  }

  // Create one map task thread per map id to write shuffle data.
  attemptsPerMapper.foreach { case (mapId, numAttempts) =>
    val appMapId = new AppMapId(appShuffleId.getAppId, appShuffleId.getAppAttempt, appShuffleId.getShuffleId, mapId)
    val thread = new Thread(new Runnable {
      override def run(): Unit = {
        (1 to numAttempts).foreach { attempt =>
          val taskAttemptId = taskAttemptIdSeed.getAndIncrement
          val isLastTaskAttempt = attempt == numAttempts
          simulateMapperTask(testValues, appMapId, taskAttemptId, isLastTaskAttempt)
        }
      }
    })
    thread.setName(String.format("[Map Thread %s]", appMapId))
    // Failures are counted rather than propagated; they are checked after all threads join.
    thread.setUncaughtExceptionHandler(new Thread.UncaughtExceptionHandler() {
      override def uncaughtException(t: Thread, e: Throwable): Unit = {
        logError(String.format("Mapper thread %s got exception", t.getName), e)
        mapThreadErrors.incrementAndGet
      }
    })
    allMapThreads.add(thread)
  }

  // Start all map task threads, then wait for every one of them to finish.
  allMapThreads.asScala.foreach(_.start())
  allMapThreads.asScala.foreach(_.join())

  val mapErrorCount = mapThreadErrors.get
  if (mapErrorCount > 0) {
    throw new RuntimeException("Number of errors in map threads: " + mapErrorCount)
  }

  // Read shuffle data from every partition, in partition order.
  val allReadData: Seq[Product2[String, String]] =
    (0 until numPartitions).flatMap(p => readShuffleData(p))
  logInfo(s"Total read records: ${allReadData.size}")

  // A HashSet makes each membership check O(1) instead of a linear ArrayList scan per key and
  // value. java.util.HashSet also permits the null test value.
  val knownValues = new util.HashSet[String](testValues)
  allReadData.foreach { record =>
    if (!knownValues.contains(record._1)) {
      throw new RuntimeException(s"Detected failure, read unexpected record key: ${record._1}")
    }
    if (!knownValues.contains(record._2)) {
      throw new RuntimeException(s"Detected failure, read unexpected record value: ${record._2}")
    }
  }

  // NOTE(review): this count assumes each mapper's final attempt contributes one record per test
  // value, that readShuffleData returns only final-attempt data, and presumably that
  // numMaps == endMapId - startMapId + 1 — confirm against simulateMapperTask/readShuffleData.
  val expectedNumRecords = testValues.size() * numMaps
  if (allReadData.size != expectedNumRecords) {
    throw new RuntimeException(s"Detected failure, expected records: $expectedNumRecords, actual records: ${allReadData.size}")
  }
  logInfo("Test run finished successfully")
}