public static void optimizeRackAwareStandbyTasks()

in streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java [393:534]


    /**
     * Tries to reduce the rack-aware assignment cost of standby tasks by redistributing standby tasks
     * between pairs of clients that live on different racks.
     *
     * <p>For every pair of clients on distinct racks, the standby tasks that could move in each direction
     * (as allowed by the standby move predicate) are collected. Only when both directions have at least one
     * candidate is a min-cost-flow graph built over those tasks and the two clients, solved, and its result
     * applied. This keeps each client's standby load balanced. Rounds repeat until a round moves no task or
     * {@code STANDBY_OPTIMIZER_MAX_ITERATION} rounds have run. The total assignment cost is logged before
     * and after the optimization.
     *
     * <p>Returns early if there are no standby tasks to optimize, or if rack-aware optimization cannot be
     * performed for the given application state.
     *
     * @param optimizationParams      optimization inputs: the application state plus optional overrides for
     *                                the cross-rack traffic cost and the non-overlap cost
     * @param kafkaStreamsAssignments current assignment per client; it is passed to
     *                                {@code initializeAssignmentsForAllClients} and standby tasks are
     *                                assigned/removed on its entries in place
     */
    public static void optimizeRackAwareStandbyTasks(final RackAwareOptimizationParams optimizationParams,
                                                     final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) {
        final ApplicationState applicationState = optimizationParams.applicationState;
        final SortedSet<TaskId> standbyTasksToOptimize = getTasksToOptimize(kafkaStreamsAssignments, optimizationParams, AssignedTask.Type.STANDBY);
        if (standbyTasksToOptimize.isEmpty()) {
            return;
        }

        if (!canPerformRackAwareOptimization(applicationState, optimizationParams, AssignedTask.Type.STANDBY)) {
            return;
        }

        initializeAssignmentsForAllClients(applicationState, kafkaStreamsAssignments);

        // Explicit overrides take precedence over the configured rack-aware costs.
        final int crossRackTrafficCost =
            optimizationParams.trafficCostOverride.orElseGet(() -> applicationState.assignmentConfigs()
                .rackAwareTrafficCost()
                .getAsInt());
        final int nonOverlapCost =
            optimizationParams.nonOverlapCostOverride.orElseGet(() -> applicationState.assignmentConfigs()
                .rackAwareNonOverlapCost()
                .getAsInt());

        final Map<ProcessId, KafkaStreamsState> kafkaStreamsStates = applicationState.kafkaStreamsStates(false);

        // Standby tasks only maintain state stores, so only changelog partitions contribute to their cost.
        final Map<TaskId, Set<TaskTopicPartition>> topicPartitionsByTaskId =
            applicationState.allTasks().values().stream().collect(Collectors.toMap(
                TaskInfo::id,
                t -> t.topicPartitions().stream().filter(TaskTopicPartition::isChangelog).collect(Collectors.toSet()))
            );

        final List<ProcessId> clientIds = new ArrayList<>(kafkaStreamsStates.keySet());
        final long initialCost = computeTotalAssignmentCost(
            topicPartitionsByTaskId,
            new ArrayList<>(standbyTasksToOptimize),
            clientIds,
            kafkaStreamsAssignments,
            kafkaStreamsStates,
            crossRackTrafficCost,
            nonOverlapCost,
            true,
            true
        );
        LOG.info("Assignment before standby task optimization has cost {}", initialCost);

        final MoveStandbyTaskPredicate moveablePredicate = getStandbyTaskMovePredicate(applicationState);
        // Standby tasks on `source` that `destination` does not already host and that the predicate allows
        // to move from source to destination, in sorted order.
        final BiFunction<KafkaStreamsAssignment, KafkaStreamsAssignment, List<TaskId>> getMovableTasks = (source, destination) -> {
            final KafkaStreamsState sourceState = kafkaStreamsStates.get(source.processId());
            // The destination state must come from the destination client, not the source.
            final KafkaStreamsState destinationState = kafkaStreamsStates.get(destination.processId());
            return source.tasks().values().stream()
                .filter(task -> task.type() == AssignedTask.Type.STANDBY)
                .filter(task -> !destination.tasks().containsKey(task.id()))
                .filter(task -> moveablePredicate.canMoveStandbyTask(sourceState, destinationState, task.id(), kafkaStreamsAssignments))
                .map(AssignedTask::id)
                .sorted()
                .collect(Collectors.toList());
        };

        final long startTime = System.currentTimeMillis();
        boolean taskMoved = true;
        int round = 0;
        final RackAwareGraphConstructor<KafkaStreamsAssignment> graphConstructor = RackAwareGraphConstructorFactory.create(
            applicationState.assignmentConfigs().rackAwareAssignmentStrategy(), standbyTasksToOptimize);
        // Pair indices refer to clientIds, so bound the loops by its size rather than by the assignment map's.
        final int clientCount = clientIds.size();
        while (taskMoved && round < STANDBY_OPTIMIZER_MAX_ITERATION) {
            taskMoved = false;
            round++;
            for (int i = 0; i < clientCount; i++) {
                final ProcessId clientId1 = clientIds.get(i);
                final KafkaStreamsAssignment assignment1 = kafkaStreamsAssignments.get(clientId1);
                final String rack1 = kafkaStreamsStates.get(clientId1).rackId().get();
                for (int j = i + 1; j < clientCount; j++) {
                    final ProcessId clientId2 = clientIds.get(j);
                    final String rack2 = kafkaStreamsStates.get(clientId2).rackId().get();
                    // Cross rack traffic can not be reduced if racks are the same
                    if (rack1.equals(rack2)) {
                        continue;
                    }
                    final KafkaStreamsAssignment assignment2 = kafkaStreamsAssignments.get(clientId2);

                    final List<TaskId> movable1 = getMovableTasks.apply(assignment1, assignment2);
                    final List<TaskId> movable2 = getMovableTasks.apply(assignment2, assignment1);

                    // There's no need to optimize if either side is empty, because the optimization
                    // can only swap tasks in order to keep the clients' load balanced
                    if (movable1.isEmpty() || movable2.isEmpty()) {
                        continue;
                    }

                    final List<TaskId> moveableTaskIds = Stream.concat(movable1.stream(), movable2.stream())
                        .sorted()
                        .collect(Collectors.toList());
                    final List<ProcessId> clientsInTaskRedistributionAttempt = Stream.of(clientId1, clientId2)
                        .sorted()
                        .collect(Collectors.toList());

                    final AssignmentGraph assignmentGraph = buildTaskGraph(
                        kafkaStreamsAssignments,
                        kafkaStreamsStates,
                        moveableTaskIds,
                        clientsInTaskRedistributionAttempt,
                        topicPartitionsByTaskId,
                        crossRackTrafficCost,
                        nonOverlapCost,
                        true,
                        true,
                        graphConstructor
                    );
                    assignmentGraph.graph.solveMinCostFlow();

                    taskMoved |= graphConstructor.assignTaskFromMinCostFlow(
                        assignmentGraph.graph,
                        clientsInTaskRedistributionAttempt,
                        moveableTaskIds,
                        kafkaStreamsAssignments,
                        assignmentGraph.taskCountByClient,
                        assignmentGraph.clientByTask,
                        (assignment, taskId) -> assignment.assignTask(new AssignedTask(taskId, AssignedTask.Type.STANDBY)),
                        (assignment, taskId) -> assignment.removeTask(new AssignedTask(taskId, AssignedTask.Type.STANDBY)),
                        (assignment, taskId) -> assignment.tasks().containsKey(taskId) && assignment.tasks().get(taskId).type() == AssignedTask.Type.STANDBY
                    );
                }
            }
        }
        final long finalCost = computeTotalAssignmentCost(
            topicPartitionsByTaskId,
            new ArrayList<>(standbyTasksToOptimize),
            clientIds,
            kafkaStreamsAssignments,
            kafkaStreamsStates,
            crossRackTrafficCost,
            nonOverlapCost,
            true,
            true
        );

        final long duration = System.currentTimeMillis() - startTime;
        LOG.info("Assignment after {} rounds and {} milliseconds for standby task optimization is {}\n with cost {}",
            round, duration, kafkaStreamsAssignments, finalCost);
    }