public void run()

in src/main/java/com/uber/rss/tools/StreamServerStressTool.java [297:506]


    /**
     * Runs one end-to-end stress iteration: writes shuffle data from simulated map tasks and,
     * when {@code servers} is non-empty, reads it back and verifies the record counts.
     *
     * <p>Steps:
     * <ol>
     *   <li>Starts periodic metric collection on a single-threaded scheduler.</li>
     *   <li>If no {@code serverHosts} were given, starts {@code numServers} servers via
     *       {@code startNewServer()}. With more than one replica, the first server of each group in
     *       the first half of the replication groups is marked for shutdown during shuffle write,
     *       and the first server of each group in the second half for shutdown during shuffle read.</li>
     *   <li>Generates test values: {@code null}, two empty byte arrays, then random
     *       single-character strings until {@code numTestValues} values exist.</li>
     *   <li>When more than two map ids exist, picks up to two random map ids that write no data.</li>
     *   <li>Runs one thread per map id in {@code [startMapId, endMapId]}. Each thread simulates
     *       between 1 and {@code numMapAttempts} attempts. Only the last attempt's task attempt id
     *       is recorded as the one to fetch.</li>
     *   <li>If {@code servers} is non-empty, verifies the written records with
     *       {@code StreamReadClientVerify}. It then shuts down the pooled write clients and all
     *       remaining servers, and optionally deletes the server root directories.</li>
     * </ol>
     *
     * @throws RuntimeException if any map thread failed, or if interrupted while waiting for the
     *     map threads to finish; verification failures are logged and rethrown
     */
    public void run() {
        scheduler = Executors.newScheduledThreadPool(1);
        scheduledMetricCollector = new ScheduledMetricCollector(serviceRegistry);
        scheduledMetricCollector.scheduleCollectingMetrics(scheduler, ServiceRegistry.DEFAULT_DATA_CENTER, ServiceRegistry.DEFAULT_TEST_CLUSTER);

        appShuffleId = new AppShuffleId(appId, appAttempt, 1);

        storage = new ShuffleFileStorage();

        // Start Remote Shuffle Service servers if no server hosts are provided
        if (serverHosts.isEmpty()) {
            for (int i = 0; i < numServers; i++) {
                startNewServer();
            }

            // only simulate bad servers when using multiple replicas
            if (numReplicas > 1) {
                List<ServerReplicationGroup> serverReplicationGroups = ServerReplicationGroupUtil.createReplicationGroups(serverDetails, numReplicas);
                // Format exactly once: a second String.format pass would re-parse the already
                // formatted text and fail on any '%' it contains.
                serverReplicationGroups.forEach(t -> logger.info(String.format("Server replication group: %s", t)));
                int halfSize = (int) Math.ceil((double) serverReplicationGroups.size() / 2.0);
                List<ServerReplicationGroup> firstHalf = serverReplicationGroups.stream().limit(halfSize).collect(Collectors.toList());
                List<ServerReplicationGroup> secondHalf = serverReplicationGroups.stream().skip(halfSize).collect(Collectors.toList());
                // Write-time and read-time victims come from disjoint halves of the group list.
                serverIdsToShutdownDuringShuffleWrite.addAll(firstHalf.stream().map(t -> t.getServers().get(0).getServerId()).collect(Collectors.toList()));
                serverIdsToShutdownDuringShuffleRead.addAll(secondHalf.stream().map(t -> t.getServers().get(0).getServerId()).collect(Collectors.toList()));
                logger.info(String.format(
                    "Servers to shutdown during shuffle write: %s, servers to shutdown during shuffle read: %s",
                    serverIdsToShutdownDuringShuffleWrite, serverIdsToShutdownDuringShuffleRead));
            }
        }

        logger.info(String.format("Server root dirs: %s", StringUtils.join(serverRootDirs, ':')));

        // Generate test values to use: null and empty values first, then random
        // single-character strings of length [0, maxTestValueLen).

        List<byte[]> testValues = new ArrayList<>();

        testValues.add(null);
        testValues.add(new byte[0]);
        testValues.add("".getBytes(StandardCharsets.UTF_8));

        while (testValues.size() < numTestValues) {
            char ch = (char) ('a' + random.nextInt(26));
            int repeats = random.nextInt(maxTestValueLen);
            String str = StringUtils.repeat(ch, repeats);
            testValues.add(str.getBytes(StandardCharsets.UTF_8));
        }

        // Simulate some map tasks which do not write any shuffle data
        // (the two random picks may coincide, so this selects one or two map ids).

        Set<Integer> mapIdsWritingEmptyData = new HashSet<>();

        if (endMapId + 1 - startMapId > 2) {
            mapIdsWritingEmptyData.add(startMapId + random.nextInt(endMapId + 1 - startMapId));
            mapIdsWritingEmptyData.add(startMapId + random.nextInt(endMapId + 1 - startMapId));
        }

        // Create map task threads to upload shuffle data

        RateCounter[] rateCounters = new RateCounter[endMapId + 1 - startMapId];

        AtomicLong taskAttemptIdSeed = new AtomicLong();

        List<Integer> simulatedNumberOfAttemptsForMappers = new ArrayList<>();

        // Task attempt ids the read side should fetch; appended to from map threads, so guarded
        // by synchronizing on the list itself.
        List<Long> fetchTaskAttemptIds = new ArrayList<>();

        ConcurrentHashMap<Integer, AtomicLong> numPartitionRecords = new ConcurrentHashMap<>();

        // Attempt counts are drawn up front so each map thread does not touch the shared random.
        for (int i = startMapId; i <= endMapId; i++) {
            int value = random.nextInt(numMapAttempts) + 1;
            simulatedNumberOfAttemptsForMappers.add(value);
        }

        for (int i = startMapId; i <= endMapId; i++) {
            final int mapId = i;
            final int index = i - startMapId;
            rateCounters[index] = new RateCounter(5000);

            AppMapId appMapId = new AppMapId(appShuffleId.getAppId(), appShuffleId.getAppAttempt(), appShuffleId.getShuffleId(), mapId);

            int simulatedNumberOfAttempts = simulatedNumberOfAttemptsForMappers.get(index);

            Thread thread = new Thread(() -> {
                for (int attempt = 1; attempt <= simulatedNumberOfAttempts; attempt++) {
                    long taskAttemptId = taskAttemptIdSeed.getAndIncrement();
                    boolean isLastTaskAttempt = attempt == simulatedNumberOfAttempts;
                    boolean simulateEmptyData = mapIdsWritingEmptyData.contains(mapId);

                    // Only the final attempt of each mapper is the one readers should fetch.
                    if (isLastTaskAttempt) {
                        synchronized (fetchTaskAttemptIds) {
                            fetchTaskAttemptIds.add(taskAttemptId);
                        }
                    }

                    simulateMapperTask(testValues,
                        appMapId,
                        taskAttemptId,
                        isLastTaskAttempt,
                        simulateEmptyData,
                        rateCounters[index],
                        numPartitionRecords);
                }
            });

            thread.setName(String.format("[Map Thread %s]", appMapId));

            thread.setUncaughtExceptionHandler((failedThread, error) -> {
                logger.error(String.format("Mapper thread %s got exception", failedThread.getName()), error);
                mapThreadErrors.incrementAndGet();
            });

            allMapThreads.add(thread);
        }

        long uploadStartTime = System.currentTimeMillis();

        // Start map task threads
        allMapThreads.forEach(Thread::start);

        // Wait for map tasks to finish. join() also waits for the uncaught-exception handler,
        // which runs on the dying thread, so mapThreadErrors is final after this loop.
        allMapThreads.forEach(t -> {
            try {
                t.join();
            } catch (InterruptedException e) {
                // Restore the interrupt flag so callers can still observe the interruption.
                Thread.currentThread().interrupt();
                M3Stats.addException(e, M3Stats.TAG_VALUE_STRESS_TOOL);
                throw new RuntimeException(e);
            }
        });

        long uploadDuration = System.currentTimeMillis() - uploadStartTime;
        // bytes per millisecond -> MiB per second
        double throughputMb = uploadDuration == 0 ? 0 : (((double) totalShuffleWrittenBytes.get()) / uploadDuration) * (1000.0 / (1024.0 * 1024.0));
        logger.info(String.format("Total written bytes: %s, records: %s, throughput: %s mb/s, total socket bytes: %s", totalShuffleWrittenBytes, totalShuffleWrittenRecords, throughputMb, totalSocketBytes));

        if (mapThreadErrors.get() > 0) {
            throw new RuntimeException("Number of errors in map threads: " + mapThreadErrors);
        }

        // Verify or delete files if necessary

        if (!servers.isEmpty()) {
            try {
                int replicasForReadClient = numReplicas;

                Map<Integer, Long> expectedTotalRecordsForEachPartition = new HashMap<>();
                numPartitionRecords.forEach((partitionId, recordCount) -> expectedTotalRecordsForEachPartition.put(partitionId, recordCount.get()));

                StreamReadClientVerify streamReadClientVerify = new StreamReadClientVerify();
                streamReadClientVerify.setRssServers(serverDetails, replicasForReadClient);
                streamReadClientVerify.setAppShuffleId(appShuffleId);
                streamReadClientVerify.setNumPartitions(numPartitions);
                streamReadClientVerify.setPartitionFanout(partitionFanout);
                streamReadClientVerify.setExpectedTotalRecords(successShuffleWrittenRecords.get());
                streamReadClientVerify.setExpectedTotalRecordsForEachPartition(expectedTotalRecordsForEachPartition);

                // Shut down the read-time victim servers; their slots in `servers` are nulled so
                // the final cleanup below skips them.
                streamReadClientVerify.setActionToSimulateBadServer(() -> {
                    synchronized (servers) {
                        synchronized (serverIdsToShutdownDuringShuffleRead) {
                            for (String serverId : serverIdsToShutdownDuringShuffleRead) {
                                StreamServer server = servers.stream()
                                    .filter(t -> t != null)
                                    .filter(t -> t.getServerId().equals(serverId))
                                    .findFirst()
                                    .orElseThrow(() -> new IllegalStateException(
                                        "Cannot find running server to shut down during shuffle read: " + serverId));
                                logger.info(String.format("Simulate bad server during shuffle read by shutting down server: %s", server));
                                shutdownServer(server);

                                int index = servers.indexOf(server);
                                servers.set(index, null);
                            }
                            serverIdsToShutdownDuringShuffleRead.clear();
                        }
                    }
                });

                logger.info(String.format("Verifying reading from servers: %s", StringUtils.join(serverDetails, ", ")));
                streamReadClientVerify.verifyRecords(usedPartitionIds.keySet(), fetchTaskAttemptIds);
                logger.info(String.format("Verified reading from servers: %s", StringUtils.join(serverDetails, ", ")));
            } catch (Throwable ex) {
                M3Stats.addException(ex, M3Stats.TAG_VALUE_STRESS_TOOL);
                logger.error(String.format("Failed to verify reading from servers: %s", StringUtils.join(serverDetails, ", ")), ex);
                throw ex;
            } finally {
                PooledWriteClientFactory.getInstance().shutdown();

                synchronized (servers) {
                    servers.forEach(t -> {
                        if (t != null) {
                            shutdownServer(t);
                        }
                    });
                }

                if (deleteFiles) {
                    try {
                        logger.info(String.format("Deleting files: %s", StringUtils.join(serverRootDirs, ", ")));
                        deleteDirectories(serverRootDirs);
                        logger.info(String.format("Deleted files: %s", StringUtils.join(serverRootDirs, ", ")));
                    } catch (Throwable ex) {
                        // Best-effort cleanup: record and log the failure, but do not rethrow.
                        M3Stats.addException(ex, M3Stats.TAG_VALUE_STRESS_TOOL);
                        logger.warn(String.format("Got some error when deleting files: %s, ignored them", StringUtils.join(serverRootDirs, ", ")), ex);
                    }
                }
            }
        }
    }