in src/main/java/com/uber/rss/tools/StreamServerStressTool.java [297:506]
/**
 * Runs the stress scenario end to end:
 * starts local shuffle servers when no external hosts are configured, spawns one writer
 * thread per simulated map task (each with a random number of attempts), waits for all
 * uploads to finish, verifies the written records by reading them back (optionally
 * shutting down "bad" servers mid-read to exercise failover), and finally shuts servers
 * down and, if configured, deletes their root directories.
 *
 * @throws RuntimeException if any mapper thread failed or verification failed
 */
public void run() {
  scheduler = Executors.newScheduledThreadPool(1);
  scheduledMetricCollector = new ScheduledMetricCollector(serviceRegistry);
  scheduledMetricCollector.scheduleCollectingMetrics(scheduler, ServiceRegistry.DEFAULT_DATA_CENTER, ServiceRegistry.DEFAULT_TEST_CLUSTER);

  appShuffleId = new AppShuffleId(appId, appAttempt, 1);
  storage = new ShuffleFileStorage();

  // Start Remote Shuffle Service servers locally if no server hosts are provided.
  if (serverHosts.isEmpty()) {
    for (int i = 0; i < numServers; i++) {
      startNewServer();
    }

    // Only simulate bad servers when using multiple replicas; with a single replica
    // shutting a server down would lose data instead of testing failover.
    if (numReplicas > 1) {
      List<ServerReplicationGroup> serverReplicationGroups =
          ServerReplicationGroupUtil.createReplicationGroups(serverDetails, numReplicas);
      // Single String.format only: the original double-wrapped format would throw
      // UnknownFormatConversionException if the group's toString contained a '%'.
      serverReplicationGroups.forEach(t ->
          logger.info(String.format("Server replication group: %s", t)));

      // Designate the first server of each group for shutdown: first half of the
      // groups during shuffle write, second half during shuffle read.
      int halfSize = (int) Math.ceil((double) serverReplicationGroups.size() / 2.0);
      List<ServerReplicationGroup> firstHalf =
          serverReplicationGroups.stream().limit(halfSize).collect(Collectors.toList());
      List<ServerReplicationGroup> secondHalf =
          serverReplicationGroups.stream().skip(halfSize).collect(Collectors.toList());
      serverIdsToShutdownDuringShuffleWrite.addAll(
          firstHalf.stream().map(t -> t.getServers().get(0).getServerId()).collect(Collectors.toList()));
      serverIdsToShutdownDuringShuffleRead.addAll(
          secondHalf.stream().map(t -> t.getServers().get(0).getServerId()).collect(Collectors.toList()));
      logger.info(String.format(
          "Servers to shutdown during shuffle write: %s, servers to shutdown during shuffle read: %s",
          serverIdsToShutdownDuringShuffleWrite, serverIdsToShutdownDuringShuffleRead));
    }
  }

  logger.info(String.format("Server root dirs: %s", StringUtils.join(serverRootDirs, ':')));

  // Generate test values to write, seeded with degenerate values (null / empty arrays)
  // to exercise edge cases, then random repeated-character strings up to maxTestValueLen.
  List<byte[]> testValues = new ArrayList<>();
  testValues.add(null);
  testValues.add(new byte[0]);
  testValues.add("".getBytes(StandardCharsets.UTF_8));
  while (testValues.size() < numTestValues) {
    char ch = (char) ('a' + random.nextInt(26));
    int repeats = random.nextInt(maxTestValueLen);
    String str = StringUtils.repeat(ch, repeats);
    testValues.add(str.getBytes(StandardCharsets.UTF_8));
  }

  // Simulate some map tasks which do not write any shuffle data: pick up to two random
  // map ids (they may collide) when there are more than two map tasks.
  Set<Integer> mapIdsWritingEmptyData = new HashSet<>();
  if (endMapId + 1 - startMapId > 2) {
    mapIdsWritingEmptyData.add(startMapId + random.nextInt(endMapId + 1 - startMapId));
    mapIdsWritingEmptyData.add(startMapId + random.nextInt(endMapId + 1 - startMapId));
  }

  // Create map task threads to upload shuffle data. Each mapper simulates 1..numMapAttempts
  // attempts; only the last attempt of each mapper is recorded for the read-back phase.
  RateCounter[] rateCounters = new RateCounter[endMapId + 1 - startMapId];
  AtomicLong taskAttemptIdSeed = new AtomicLong();
  List<Integer> simulatedNumberOfAttemptsForMappers = new ArrayList<>();
  List<Long> fetchTaskAttemptIds = new ArrayList<>();
  ConcurrentHashMap<Integer, AtomicLong> numPartitionRecords = new ConcurrentHashMap<>();

  for (int i = startMapId; i <= endMapId; i++) {
    int value = random.nextInt(numMapAttempts) + 1;
    simulatedNumberOfAttemptsForMappers.add(value);
  }

  for (int i = startMapId; i <= endMapId; i++) {
    final int mapId = i;
    final int index = i - startMapId;
    rateCounters[index] = new RateCounter(5000);
    AppMapId appMapId = new AppMapId(appShuffleId.getAppId(), appShuffleId.getAppAttempt(), appShuffleId.getShuffleId(), mapId);
    int simulatedNumberOfAttempts = simulatedNumberOfAttemptsForMappers.get(index);
    Thread thread = new Thread(() -> {
      for (int attempt = 1; attempt <= simulatedNumberOfAttempts; attempt++) {
        long taskAttemptId = taskAttemptIdSeed.getAndIncrement();
        boolean isLastTaskAttempt = attempt == simulatedNumberOfAttempts;
        boolean simulateEmptyData = mapIdsWritingEmptyData.contains(mapId);
        if (isLastTaskAttempt) {
          // fetchTaskAttemptIds is a plain ArrayList shared across mapper threads,
          // so mutation must be guarded by its monitor.
          synchronized (fetchTaskAttemptIds) {
            fetchTaskAttemptIds.add(taskAttemptId);
          }
        }
        simulateMapperTask(testValues,
            appMapId,
            taskAttemptId,
            isLastTaskAttempt,
            simulateEmptyData,
            rateCounters[index],
            numPartitionRecords);
      }
    });
    thread.setName(String.format("[Map Thread %s]", appMapId));
    // Count failures instead of crashing the tool; the aggregate is checked after join.
    thread.setUncaughtExceptionHandler((t, e) -> {
      logger.error(String.format("Mapper thread %s got exception", t.getName()), e);
      mapThreadErrors.incrementAndGet();
    });
    allMapThreads.add(thread);
  }

  long uploadStartTime = System.currentTimeMillis();

  // Start map task threads, then wait for all of them to finish.
  allMapThreads.forEach(Thread::start);
  allMapThreads.forEach(t -> {
    try {
      t.join();
    } catch (InterruptedException e) {
      // Restore the interrupt flag so callers can still observe the interruption.
      Thread.currentThread().interrupt();
      M3Stats.addException(e, M3Stats.TAG_VALUE_STRESS_TOOL);
      throw new RuntimeException(e);
    }
  });

  long uploadDuration = System.currentTimeMillis() - uploadStartTime;
  // bytes/ms -> MB/s; guard against a zero-duration run.
  double throughputMb = uploadDuration == 0
      ? 0
      : (((double) totalShuffleWrittenBytes.get()) / uploadDuration) * (1000.0 / (1024.0 * 1024.0));
  logger.info(String.format(
      "Total written bytes: %s, records: %s, throughput: %s mb/s, total socket bytes: %s",
      totalShuffleWrittenBytes, totalShuffleWrittenRecords, throughputMb, totalSocketBytes));

  if (mapThreadErrors.get() > 0) {
    throw new RuntimeException("Number of errors in map threads: " + mapThreadErrors);
  }

  // Verify written data by reading it back, then shut down servers / delete files.
  if (!servers.isEmpty()) {
    try {
      int replicasForReadClient = numReplicas;
      Map<Integer, Long> expectedTotalRecordsForEachPartition = new HashMap<>();
      numPartitionRecords.forEach((partition, count) ->
          expectedTotalRecordsForEachPartition.put(partition, count.get()));

      StreamReadClientVerify streamReadClientVerify = new StreamReadClientVerify();
      streamReadClientVerify.setRssServers(serverDetails, replicasForReadClient);
      streamReadClientVerify.setAppShuffleId(appShuffleId);
      streamReadClientVerify.setNumPartitions(numPartitions);
      streamReadClientVerify.setPartitionFanout(partitionFanout);
      streamReadClientVerify.setExpectedTotalRecords(successShuffleWrittenRecords.get());
      streamReadClientVerify.setExpectedTotalRecordsForEachPartition(expectedTotalRecordsForEachPartition);
      // During the read, shut down the designated "bad" servers to exercise read failover.
      streamReadClientVerify.setActionToSimulateBadServer(() -> {
        synchronized (servers) {
          synchronized (serverIdsToShutdownDuringShuffleRead) {
            for (String serverId : serverIdsToShutdownDuringShuffleRead) {
              StreamServer server = servers.stream()
                  .filter(t -> t != null)
                  .filter(t -> t.getServerId().equals(serverId))
                  .findFirst()
                  .get();
              logger.info(String.format("Simulate bad server during shuffle read by shutting down server: %s", server));
              shutdownServer(server);
              int index = servers.indexOf(server);
              // Null out the slot (rather than remove) so other indices stay stable.
              servers.set(index, null);
            }
            serverIdsToShutdownDuringShuffleRead.clear();
          }
        }
      });

      logger.info(String.format("Verifying reading from servers: %s", StringUtils.join(serverDetails, ", ")));
      streamReadClientVerify.verifyRecords(usedPartitionIds.keySet(), fetchTaskAttemptIds);
      // "Verified" here: the original duplicated the "Verifying" message after completion.
      logger.info(String.format("Verified reading from servers: %s", StringUtils.join(serverDetails, ", ")));
    } catch (Throwable ex) {
      M3Stats.addException(ex, M3Stats.TAG_VALUE_STRESS_TOOL);
      logger.error(String.format("Failed to verify reading from servers: %s", StringUtils.join(serverDetails, ", ")), ex);
      throw ex;
    } finally {
      PooledWriteClientFactory.getInstance().shutdown();
      synchronized (servers) {
        servers.forEach(t -> {
          if (t != null) {
            shutdownServer(t);
          }
        });
      }
      if (deleteFiles) {
        // Best-effort cleanup: deletion failures are logged and deliberately ignored.
        try {
          logger.info(String.format("Deleting files: %s", StringUtils.join(serverRootDirs, ", ")));
          deleteDirectories(serverRootDirs);
          logger.info(String.format("Deleted files: %s", StringUtils.join(serverRootDirs, ", ")));
        } catch (Throwable ex) {
          M3Stats.addException(ex, M3Stats.TAG_VALUE_STRESS_TOOL);
          // Supply the %s argument and the exception: the original logged the literal
          // "%s" and dropped the throwable entirely.
          logger.info(String.format("Got some error when deleting files: %s, ignored them",
              StringUtils.join(serverRootDirs, ", ")), ex);
        }
      }
    }
  }

  // Defensive re-check; all mapper threads have joined, so this normally matches the
  // earlier check, but it also covers the no-verification (servers empty) path.
  if (mapThreadErrors.get() > 0) {
    throw new RuntimeException("Number of errors in map threads: " + mapThreadErrors);
  }
}