in src/main/java/com/uber/rss/tools/StreamReadClientVerify.java [98:175]
/**
 * Reads back the data blocks of the given partitions and checks them against the expected counts.
 *
 * <p>For every partition a new {@code MultiServerSocketReadClient} is created, connected, drained
 * with {@code readDataBlock()} until it returns {@code null}, and closed. The number of blocks read
 * for the partition is compared with {@code expectedTotalRecordsForEachPartition} (a missing entry
 * counts as 0). After all partitions are processed, the overall number of blocks read is compared
 * with {@code expectedTotalRecords}, unless that value is 0.
 *
 * <p>When the overall read counter reaches {@code expectedTotalRecords / 2} and
 * {@code actionToSimulateBadServer} is set, that action is run once in the middle of reading.
 *
 * @param partitionIds partitions to verify; {@code null} means every partition in
 *     {@code [0, numPartitions)}
 * @param fetchTaskAttemptIds task attempt ids handed to the read client's
 *     {@code ReadClientDataOptions}
 * @throws RuntimeException if a payload is longer than {@code maxValueLen}, if a partition's record
 *     count differs from its expected count, or if the total record count differs from a non-zero
 *     {@code expectedTotalRecords}
 */
public void verifyRecords(Collection<Integer> partitionIds, Collection<Long> fetchTaskAttemptIds) {
  // Read client settings are the same for every partition, so compute them once.
  final int socketTimeoutMillis = 120 * 1000;
  final int dataAvailableWaitTime = socketTimeoutMillis * 3;
  final int dataAvailablePollInterval = 10;
  final boolean checkDataConsistency = true;

  AtomicLong totalReadRecords = new AtomicLong();

  if (partitionIds == null) {
    partitionIds = IntStream.range(0, numPartitions).boxed().collect(Collectors.toList());
    logger.info(String.format("Verifying record for partitions: [%s, %s)", 0, numPartitions));
  } else {
    logger.info(String.format("Verifying record for partitions: %s", StringUtils.join(partitionIds, ",")));
  }

  for (int partition : partitionIds) {
    AppShufflePartitionId appShufflePartitionId = new AppShufflePartitionId(
        appId, appAttempt, shuffleId, partition);
    List<ServerReplicationGroup> serverReplicationGroups =
        ServerReplicationGroupUtil.createReplicationGroupsForPartition(
            rssServers, numReplicas, partition, partitionFanout);
    MultiServerReadClient readClient = new MultiServerSocketReadClient(serverReplicationGroups,
        socketTimeoutMillis,
        new ClientRetryOptions(dataAvailablePollInterval, dataAvailableWaitTime, serverDetail->serverDetail),
        "user1",
        appShufflePartitionId,
        new ReadClientDataOptions(
            fetchTaskAttemptIds,
            dataAvailablePollInterval,
            dataAvailableWaitTime),
        checkDataConsistency);
    logger.info(String.format("Connecting replicated read client: %s", readClient));
    try {
      // Connect inside the try block so the client is still closed when connecting fails.
      readClient.connect();

      long numReadPartitionRecords = 0;
      TaskDataBlock record = readClient.readDataBlock();
      while (record != null) {
        numReadPartitionRecords++;
        long totalReadRecordsValue = totalReadRecords.incrementAndGet();
        // Trigger the optional failure injection exactly when half of the expected records are read.
        if (totalReadRecordsValue == expectedTotalRecords/2) {
          if (actionToSimulateBadServer != null) {
            logger.info("Simulate bad server during shuffle read");
            actionToSimulateBadServer.run();
          }
        }
        if (record.getPayload() != null && record.getPayload().length > maxValueLen) {
          // Report the payload length; formatting the array itself would only print its identity.
          throw new RuntimeException(String.format(
              "Read wrong value len %s after reading %s records for %s from server %s",
              record.getPayload().length, numReadPartitionRecords, appShufflePartitionId, serverReplicationGroups));
        }
        record = readClient.readDataBlock();
      }
      logger.info(String.format("Closing read client for %s", appShufflePartitionId));

      long expectedNumRecords = expectedTotalRecordsForEachPartition.getOrDefault(partition, 0L);
      if (numReadPartitionRecords != expectedNumRecords) {
        throw new RuntimeException(String.format(
            "Verify error for partition %s, servers %s, expected records: %s, actual records: %s",
            appShufflePartitionId, serverReplicationGroups, expectedNumRecords, numReadPartitionRecords));
      }
      logger.info(String.format("Verified %s records for %s from server %s",
          numReadPartitionRecords, appShufflePartitionId, serverReplicationGroups));
    } finally {
      readClient.close();
    }
  }

  String logMsg = String.format("Total expected records: %s, total records read from servers: %s", expectedTotalRecords, totalReadRecords);
  logger.info(logMsg);
  // An expected total of 0 skips the overall check; only the per-partition checks above apply.
  if (expectedTotalRecords != 0 && expectedTotalRecords != totalReadRecords.get()) {
    throw new RuntimeException(logMsg);
  }
}