in streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java [393:534]
public static void optimizeRackAwareStandbyTasks(final RackAwareOptimizationParams optimizationParams,
                                                 final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) {
    // Tries to lower the total assignment cost (cross-rack traffic cost + non-overlap cost) of the
    // already-assigned standby tasks by redistributing them between every pair of clients that sit
    // on different racks, solving a min-cost-flow problem for each pair. kafkaStreamsAssignments is
    // modified in place. Rounds repeat while some pair moved a task, up to
    // STANDBY_OPTIMIZER_MAX_ITERATION rounds.
    final ApplicationState applicationState = optimizationParams.applicationState;
    final SortedSet<TaskId> standbyTasksToOptimize = getTasksToOptimize(kafkaStreamsAssignments, optimizationParams, AssignedTask.Type.STANDBY);
    if (standbyTasksToOptimize.isEmpty()) {
        return;
    }
    if (!canPerformRackAwareOptimization(applicationState, optimizationParams, AssignedTask.Type.STANDBY)) {
        return;
    }
    initializeAssignmentsForAllClients(applicationState, kafkaStreamsAssignments);
    final int crossRackTrafficCost =
        optimizationParams.trafficCostOverride.orElseGet(() -> applicationState.assignmentConfigs()
            .rackAwareTrafficCost()
            .getAsInt());
    final int nonOverlapCost =
        optimizationParams.nonOverlapCostOverride.orElseGet(() -> applicationState.assignmentConfigs()
            .rackAwareNonOverlapCost()
            .getAsInt());
    final Map<ProcessId, KafkaStreamsState> kafkaStreamsStates = applicationState.kafkaStreamsStates(false);
    // Only changelog partitions are considered when costing standby placement.
    final Map<TaskId, Set<TaskTopicPartition>> topicPartitionsByTaskId =
        applicationState.allTasks().values().stream().collect(Collectors.toMap(
            TaskInfo::id,
            t -> t.topicPartitions().stream().filter(TaskTopicPartition::isChangelog).collect(Collectors.toSet()))
        );
    final List<ProcessId> clientIds = new ArrayList<>(kafkaStreamsStates.keySet());
    final long initialCost = computeTotalAssignmentCost(
        topicPartitionsByTaskId,
        new ArrayList<>(standbyTasksToOptimize),
        clientIds,
        kafkaStreamsAssignments,
        kafkaStreamsStates,
        crossRackTrafficCost,
        nonOverlapCost,
        true,
        true
    );
    LOG.info("Assignment before standby task optimization has cost {}", initialCost);
    final MoveStandbyTaskPredicate moveablePredicate = getStandbyTaskMovePredicate(applicationState);
    // Returns, sorted, the standby tasks of `source` that `destination` does not already own and
    // that the move predicate allows to go from source to destination.
    final BiFunction<KafkaStreamsAssignment, KafkaStreamsAssignment, List<TaskId>> getMovableTasks = (source, destination) -> {
        // Both states depend only on the client pair, so resolve them once rather than per task.
        // The destination state must be looked up with the destination's process id.
        final KafkaStreamsState sourceState = kafkaStreamsStates.get(source.processId());
        final KafkaStreamsState destinationState = kafkaStreamsStates.get(destination.processId());
        return source.tasks().values().stream()
            .filter(task -> task.type() == AssignedTask.Type.STANDBY)
            .filter(task -> !destination.tasks().containsKey(task.id()))
            .filter(task -> moveablePredicate.canMoveStandbyTask(sourceState, destinationState, task.id(), kafkaStreamsAssignments))
            .map(AssignedTask::id)
            .sorted()
            .collect(Collectors.toList());
    };
    final long startTime = System.currentTimeMillis();
    boolean taskMoved = true;
    int round = 0;
    final RackAwareGraphConstructor<KafkaStreamsAssignment> graphConstructor = RackAwareGraphConstructorFactory.create(
        applicationState.assignmentConfigs().rackAwareAssignmentStrategy(), standbyTasksToOptimize);
    // Bound the pair loops by the list actually being indexed, so a size mismatch between
    // kafkaStreamsStates and kafkaStreamsAssignments cannot cause an out-of-bounds access or
    // silently skip clients.
    final int clientCount = clientIds.size();
    while (taskMoved && round < STANDBY_OPTIMIZER_MAX_ITERATION) {
        taskMoved = false;
        round++;
        for (int i = 0; i < clientCount; i++) {
            final ProcessId clientId1 = clientIds.get(i);
            final KafkaStreamsAssignment assignment1 = kafkaStreamsAssignments.get(clientId1);
            for (int j = i + 1; j < clientCount; j++) {
                final ProcessId clientId2 = clientIds.get(j);
                final KafkaStreamsAssignment assignment2 = kafkaStreamsAssignments.get(clientId2);
                final String rack1 = kafkaStreamsStates.get(clientId1).rackId().get();
                final String rack2 = kafkaStreamsStates.get(clientId2).rackId().get();
                // Cross rack traffic can not be reduced if racks are the same
                if (rack1.equals(rack2)) {
                    continue;
                }
                final List<TaskId> movable1 = getMovableTasks.apply(assignment1, assignment2);
                final List<TaskId> movable2 = getMovableTasks.apply(assignment2, assignment1);
                // There's no need to optimize if one is empty because the optimization
                // can only swap tasks to keep the client's load balanced
                if (movable1.isEmpty() || movable2.isEmpty()) {
                    continue;
                }
                final List<TaskId> moveableTaskIds = Stream.concat(movable1.stream(), movable2.stream())
                    .sorted()
                    .collect(Collectors.toList());
                final List<ProcessId> clientsInTaskRedistributionAttempt = Stream.of(clientId1, clientId2)
                    .sorted()
                    .collect(Collectors.toList());
                final AssignmentGraph assignmentGraph = buildTaskGraph(
                    kafkaStreamsAssignments,
                    kafkaStreamsStates,
                    moveableTaskIds,
                    clientsInTaskRedistributionAttempt,
                    topicPartitionsByTaskId,
                    crossRackTrafficCost,
                    nonOverlapCost,
                    true,
                    true,
                    graphConstructor
                );
                assignmentGraph.graph.solveMinCostFlow();
                // Apply the min-cost-flow result back onto the assignments; the callbacks add,
                // remove, and detect STANDBY tasks respectively.
                taskMoved |= graphConstructor.assignTaskFromMinCostFlow(
                    assignmentGraph.graph,
                    clientsInTaskRedistributionAttempt,
                    moveableTaskIds,
                    kafkaStreamsAssignments,
                    assignmentGraph.taskCountByClient,
                    assignmentGraph.clientByTask,
                    (assignment, taskId) -> assignment.assignTask(new AssignedTask(taskId, AssignedTask.Type.STANDBY)),
                    (assignment, taskId) -> assignment.removeTask(new AssignedTask(taskId, AssignedTask.Type.STANDBY)),
                    (assignment, taskId) -> assignment.tasks().containsKey(taskId) && assignment.tasks().get(taskId).type() == AssignedTask.Type.STANDBY
                );
            }
        }
    }
    final long finalCost = computeTotalAssignmentCost(
        topicPartitionsByTaskId,
        new ArrayList<>(standbyTasksToOptimize),
        clientIds,
        kafkaStreamsAssignments,
        kafkaStreamsStates,
        crossRackTrafficCost,
        nonOverlapCost,
        true,
        true
    );
    final long duration = System.currentTimeMillis() - startTime;
    LOG.info("Assignment after {} rounds and {} milliseconds for standby task optimization is {}\n with cost {}",
        round, duration, kafkaStreamsAssignments, finalCost);
}