pkg/task/taskset.go (366 lines of code) (raw):

// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package task

import (
	"context"
	"fmt"
	"strings"

	"github.com/GoogleCloudPlatform/khi/pkg/common/typedmap"
	"github.com/GoogleCloudPlatform/khi/pkg/task/taskid"
	"golang.org/x/exp/slices"
)

// LabelPredicate is an alias for a function reporting whether a value of type T satisfies some condition.
// NOTE(review): it is not referenced in this file; presumably it is used elsewhere in the package to filter tasks by label values.
type LabelPredicate[T any] = func(v T) bool

// TaskSet is a collection of tasks.
// It provides collection operations used to construct the task graph to execute.
type TaskSet struct {
	// tasks holds the tasks of this set. NewTaskSet and Add reject duplicate task IDs.
	tasks []UntypedTask
	// runnable marks a set whose tasks are topologically sorted with every dependency resolved.
	// Within this file only ResolveTask sets it to true, and DumpGraphviz refuses to run when it is false.
	runnable bool
}

// sortTaskResult represents the result of topologically sorting tasks.
type sortTaskResult struct {
	// TopologicalSortedTasks is the list of tasks in topological order. It is nil when sorting failed.
	TopologicalSortedTasks []UntypedTask
	// MissingDependencies is the list of task references that some task depends on but no task in the set provides.
	// It is empty when sorting succeeded.
	MissingDependencies []taskid.UntypedTaskReference
	// CyclicDependencyPath describes a dependency cycle found while sorting.
	// It is "" when no cycle was detected; a non-empty value implies Runnable is false.
	CyclicDependencyPath string
	// Runnable indicates whether this task graph is runnable: the tasks were sorted in topological order and all of their dependencies were resolved.
	Runnable bool
}

// NewTaskSet creates a new TaskSet with the given tasks.
// Returns an error if there are duplicate task IDs.
func NewTaskSet(tasks []UntypedTask) (*TaskSet, error) { taskIDs := map[string]struct{}{} for _, def := range tasks { id := def.UntypedID() if _, exist := taskIDs[id.String()]; exist { return nil, fmt.Errorf("multiple tasks have the same ID %s", id) } taskIDs[id.String()] = struct{}{} } return &TaskSet{ tasks: slices.Clone(tasks), runnable: false, }, nil } // Add a task definiton to current TaskSet. // Returns an error when duplicated task Id is assigned on the task. func (s *TaskSet) Add(newTask UntypedTask) error { taskIdMap := map[string]interface{}{} for _, task := range s.tasks { taskIdMap[task.UntypedID().String()] = struct{}{} } if _, exist := taskIdMap[newTask.UntypedID().String()]; exist { return fmt.Errorf("task id:%s is duplicated. Task ID must be unique", newTask.UntypedID()) } s.tasks = append(s.tasks, newTask) return nil } func (s *TaskSet) GetAll() []UntypedTask { return slices.Clone(s.tasks) } // Remove a task definition from current DefinitionSet. // Returns error if the definition does not exist func (s *TaskSet) Remove(id string) error { taskIdMap := map[string]interface{}{} for _, task := range s.tasks { taskIdMap[task.UntypedID().String()] = struct{}{} } if _, exist := taskIdMap[id]; !exist { return fmt.Errorf("task definition id:%s is not found in this set", id) } n := 0 for _, task := range s.tasks { if task.UntypedID().String() != id { s.tasks[n] = task n++ } } s.tasks = s.tasks[:n] return nil } // Get returns a task with the given string task ID notation. func (s *TaskSet) Get(id string) (UntypedTask, error) { for _, task := range s.tasks { if task.UntypedID().String() == id { return task, nil } } return nil, fmt.Errorf("task %s was not found", id) } // WrapGraph adds init task and done task to the runnable graph. // The init task named as `subgraphId`-init has the dependency provided in the subgraphDependency argument. And the init task will be dependency of the tasks that had no dependency before calling this method. 
// The done task named as `subgraphId`-done has the dependency of the tasks that were not dependent from any other tasks. // The result task set will be resolvable with `[the init task] -> [the other tasks] -> [the done task]` func (s *TaskSet) WrapGraph(subgraphId taskid.UntypedTaskImplementationID, subgraphDependency []taskid.UntypedTaskReference) (*TaskSet, error) { initTaskId := taskid.NewImplementationID(taskid.NewTaskReference[any](fmt.Sprintf("%s-init", subgraphId.ReferenceIDString())), subgraphId.GetTaskImplementationHash()) doneTaskId := taskid.NewImplementationID(taskid.NewTaskReference[any](fmt.Sprintf("%s-done", subgraphId.ReferenceIDString())), subgraphId.GetTaskImplementationHash()) rewiredTasks := []UntypedTask{} tasksNotDependentFromAnyMap := map[string]struct{}{} for _, t := range s.tasks { if len(t.Dependencies()) == 0 { capturedTask := t rewiredTask := &wrapGraphFirstTask{ task: capturedTask, dependencies: []taskid.UntypedTaskReference{initTaskId.Ref()}, } rewiredTasks = append(rewiredTasks, rewiredTask) } else { rewiredTasks = append(rewiredTasks, t) } tasksNotDependentFromAnyMap[t.UntypedID().GetUntypedReference().String()] = struct{}{} } for _, t := range s.tasks { for _, dep := range t.Dependencies() { delete(tasksNotDependentFromAnyMap, dep.String()) } } doneTaskDependencies := []taskid.UntypedTaskReference{ initTaskId.Ref(), } for k := range tasksNotDependentFromAnyMap { doneTaskDependencies = append(doneTaskDependencies, taskid.NewTaskReference[any](k)) } // Sort to make result stable slices.SortFunc(doneTaskDependencies, func(a, b taskid.UntypedTaskReference) int { return strings.Compare(a.String(), b.String()) }) initTask := NewTask(initTaskId, subgraphDependency, func(ctx context.Context) (any, error) { return nil, nil }) doneTask := NewTask(doneTaskId, doneTaskDependencies, func(ctx context.Context) (any, error) { return nil, nil }) rewiredTasks = append(rewiredTasks, initTask, doneTask) return NewTaskSet(rewiredTasks) } func (s 
*TaskSet) sortTaskGraph() *sortTaskResult { // To check if there were no cyclic task path or missing inputs, // perform the topological sorting algorithm known as Kahn's algorithm // Reference: https://en.wikipedia.org/wiki/Topological_sorting nonResolvedTasksMap := map[string]UntypedTask{} currentMissingTaskDependencies := map[string]map[string]interface{}{} currentMissingTaskSourceCount := map[string]int{} taskCount := 0 // Initialize currentMissingTaskDependencies and currentMissingTaskSourceCount for all tasks. for _, task := range s.tasks { taskID := task.UntypedID() sourceCount := 0 missingDependencies := map[string]interface{}{} for _, dependency := range task.Dependencies() { if _, found := missingDependencies[dependency.ReferenceIDString()]; !found { missingDependencies[dependency.ReferenceIDString()] = struct{}{} sourceCount += 1 } } currentMissingTaskDependencies[taskID.String()] = missingDependencies nonResolvedTasksMap[taskID.String()] = task currentMissingTaskSourceCount[taskID.String()] = sourceCount taskCount += 1 } topologicalSortedTasks := []UntypedTask{} for i := 0; i < taskCount; i++ { var nextTaskID string = "N/A" for _, taskId := range sortedMapKeys(nonResolvedTasksMap) { // Needs task sorting to get the same result every time. if currentMissingTaskSourceCount[taskId] == 0 { nextTaskID = taskId } } if nextTaskID != "N/A" { nextTask := nonResolvedTasksMap[nextTaskID] delete(nonResolvedTasksMap, nextTaskID) removingDependencyId := nextTask.UntypedID().ReferenceIDString() for taskId := range nonResolvedTasksMap { if _, exist := currentMissingTaskDependencies[taskId][removingDependencyId]; exist { delete(currentMissingTaskDependencies[taskId], removingDependencyId) currentMissingTaskSourceCount[taskId]-- } } topologicalSortedTasks = append(topologicalSortedTasks, nextTask) } else { // Failed to perform topological sort. // Gathers the cause of the failure. 
missingTaskIdsInMap := map[string]interface{}{} for taskId := range nonResolvedTasksMap { for dependency := range currentMissingTaskDependencies[taskId] { missingTaskIdsInMap[dependency] = struct{}{} } } for _, task := range nonResolvedTasksMap { delete(missingTaskIdsInMap, task.UntypedID().ReferenceIDString()) } missingSources := []taskid.UntypedTaskReference{} for source := range missingTaskIdsInMap { missingSources = append(missingSources, taskid.NewTaskReference[any](source)) } if len(missingSources) == 0 { // If there are no missing dependencies but still can't resolve the graph, // it means there is a cyclic dependency return getSortTaskResultWithDetailCyclicDependency(nonResolvedTasksMap, currentMissingTaskDependencies, missingSources) } return &sortTaskResult{ Runnable: false, TopologicalSortedTasks: nil, CyclicDependencyPath: "", MissingDependencies: missingSources, } } } return &sortTaskResult{ Runnable: true, TopologicalSortedTasks: topologicalSortedTasks, MissingDependencies: []taskid.UntypedTaskReference{}, CyclicDependencyPath: "", } } // ResolveTask generate a super set of this task set with adding required tasks from availableTaskSet. // The returned TaskSet of this method will be `runnable` and topologically sorted. func (s *TaskSet) ResolveTask(availableTaskSet *TaskSet) (*TaskSet, error) { sourceTaskSet := s sortResult := sourceTaskSet.sortTaskGraph() if sortResult.Runnable { return &TaskSet{tasks: sortResult.TopologicalSortedTasks, runnable: true}, nil } else { // the sourceTaskSet can't be topologically sorted with its own tasks. 
// Try to add missing dependencies from availableTaskSet complementedTask := []UntypedTask{} resolutionFailure := false var missingTaskId taskid.UntypedTaskReference for _, missingSource := range sortResult.MissingDependencies { matched := []UntypedTask{} for _, task := range availableTaskSet.tasks { missingSourceReference := missingSource.ReferenceIDString() if task.UntypedID().ReferenceIDString() == missingSourceReference { matched = append(matched, task) } } // sort matched tasks with selection priority for in case when there are 2 or more tasks can be usable for resolving required dependency maxPriority := -1 var maxPriorityTask UntypedTask for _, task := range matched { priority := typedmap.GetOrDefault(task.Labels(), LabelKeyTaskSelectionPriority, 0) if priority >= maxPriority { maxPriority = priority maxPriorityTask = task } } if maxPriorityTask != nil { complementedTask = append(complementedTask, maxPriorityTask) } else { resolutionFailure = true missingTaskId = missingSource } } if sortResult.CyclicDependencyPath != "" { return nil, fmt.Errorf("failed to resolve the task graph. \n The graph contains cyclic dependency\n%s", sortResult.CyclicDependencyPath) } if !resolutionFailure { tasks := append(slices.Clone(sourceTaskSet.tasks), complementedTask...) sourceTaskSet = &TaskSet{ tasks: tasks, runnable: false, } return sourceTaskSet.ResolveTask(availableTaskSet) } return nil, fmt.Errorf("failed to resolve the task set.\n Missing %s\nAvailable tasks:\n%v", missingTaskId.ReferenceIDString(), dumpTaskIDList(availableTaskSet)) } } // DumpGraphviz returns task graph as graphviz string for debugging purpose. // The generated string can be converted to DAG graph using `dot` command. 
func (s *TaskSet) DumpGraphviz() (string, error) { if !s.runnable { return "", fmt.Errorf("can't draw a graph for non runnable graph") } result := "digraph G {\n" result += "start [shape=\"diamond\",fillcolor=gray,style=filled]\n" for _, task := range s.tasks { // concept of the feature is not defined in task level, but it's better to be included in the dumpped graph. // The ID can't be referenced directly because of the circular dependency issue, thus this code define the ID with NewLabelKey feature := typedmap.GetOrDefault(task.Labels(), NewTaskLabelKey[bool]("khi.google.com/inspection/feature"), false) shape := "circle" if feature { shape = "doublecircle" } result += fmt.Sprintf("%s [shape=\"%s\",label=\"%s\"]\n", graphVizValidId(task.UntypedID().String()), shape, task.UntypedID()) } for _, task := range s.tasks { if len(task.Dependencies()) == 0 { result += fmt.Sprintf("start -> %s\n", graphVizValidId(task.UntypedID().String())) } } sourceRelation := map[string]UntypedTask{} for _, task := range s.tasks { sources := task.Dependencies() for _, source := range sources { sourceTask := sourceRelation[source.ReferenceIDString()] result += fmt.Sprintf("%s -> %s\n", graphVizValidId(sourceTask.UntypedID().String()), graphVizValidId(task.UntypedID().String())) } sourceRelation[task.UntypedID().ReferenceIDString()] = task } result += "}" return result, nil } func sortedMapKeys[T any](inputMap map[string]T) []string { result := []string{} for key := range inputMap { result = append(result, key) } slices.SortFunc(result, strings.Compare) return result } func graphVizValidId(id string) string { return strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(id, "-", "_"), "/", "_"), ".", "_"), "#", "_") } // getSortTaskResultWithDetailCyclicDependency detects and reports cyclic dependencies in the task graph. // It returns a sortTaskResult with the details of the cyclic dependency. 
func getSortTaskResultWithDetailCyclicDependency( nonResolvedTasksMap map[string]UntypedTask, currentMissingTaskDependencies map[string]map[string]interface{}, missingSources []taskid.UntypedTaskReference, ) *sortTaskResult { for _, taskID := range sortedMapKeys(nonResolvedTasksMap) { dependentFrom := map[string]string{} // A map tracks the path where the task depended from. dependentFrom[taskID] = "START" queue := map[string]struct{}{} queue[taskID] = struct{}{} for len(queue) > 0 { nextTaskID := sortedMapKeys(queue)[0] delete(queue, nextTaskID) for dependency := range currentMissingTaskDependencies[nextTaskID] { prevParent := "" for visitedTask := range dependentFrom { // The task ID contains implementation hash(#default), it should match with the prefix. if strings.HasPrefix(visitedTask, dependency) { prevParent = dependentFrom[visitedTask] break } } if prevParent != "" { if prevParent == "START" { // now we found the path to loop back to the START. trace back the cyclic path. path := []string{} queue := map[string]struct{}{} queue[nextTaskID] = struct{}{} for len(queue) > 0 { nextTaskID := sortedMapKeys(queue)[0] if nextTaskID == "START" { break } delete(queue, nextTaskID) path = append(path, nextTaskID) queue[dependentFrom[nextTaskID]] = struct{}{} } return &sortTaskResult{ Runnable: false, TopologicalSortedTasks: nil, CyclicDependencyPath: fmt.Sprintf("... 
-> %s] -> [%s] -> [%s -> ...", path[len(path)-1], strings.Join(path, " -> "), path[0]), MissingDependencies: missingSources, } } } else { for taskID := range nonResolvedTasksMap { if strings.HasPrefix(taskID, dependency) { dependentFrom[taskID] = nextTaskID queue[taskID] = struct{}{} break } } } } } } nonResolvedTaskKeys := sortedMapKeys(nonResolvedTasksMap) missingSourceDependencyInfo := []string{} for missingDependencyKey, missingDependency := range currentMissingTaskDependencies { missingSourceDependencyInfo = append(missingSourceDependencyInfo, fmt.Sprintf("%s -> %v", missingDependencyKey, sortedMapKeys(missingDependency))) } // This should be unreachable if the graph has a cyclic dependency panic(fmt.Sprintf("unreachable. findCyclicDependency was called on a task graph with a task graph without any cyclic dependency. \n debug info: \n non resolved tasks: %v \n missing dependencies: %s", nonResolvedTaskKeys, missingSourceDependencyInfo)) } func dumpTaskIDList(taskSet *TaskSet) string { taskIDs := []string{} for _, task := range taskSet.tasks { taskIDs = append(taskIDs, task.UntypedID().String()) } slices.SortFunc(taskIDs, strings.Compare) result := "" for _, taskID := range taskIDs { result += fmt.Sprintf("- %s\n", taskID) } return result } // wrapGraphFirstTask is an implementation of Task to rewrite its dependency for wrapping graphs as a sub graph. // This is only used in the WrapGraph method. type wrapGraphFirstTask struct { task UntypedTask dependencies []taskid.UntypedTaskReference } // Dependencies implements Task. func (w *wrapGraphFirstTask) Dependencies() []taskid.UntypedTaskReference { return w.dependencies } // ID implements Task. func (w *wrapGraphFirstTask) ID() taskid.TaskImplementationID[any] { untypedID := w.task.UntypedID() return taskid.NewImplementationID(taskid.NewTaskReference[any](untypedID.GetUntypedReference().String()), untypedID.GetTaskImplementationHash()) } // Labels implements Task. 
func (w *wrapGraphFirstTask) Labels() *typedmap.ReadonlyTypedMap {
	// Rewiring only touches dependencies, so the wrapped task's labels are exposed as they are.
	wrapped := w.task
	return wrapped.Labels()
}

// Run implements Task by delegating to the wrapped task's untyped runner.
func (w *wrapGraphFirstTask) Run(ctx context.Context) (any, error) {
	wrapped := w.task
	return wrapped.UntypedRun(ctx)
}

// UntypedRun runs the wrapped task, exactly like Run.
func (w *wrapGraphFirstTask) UntypedRun(ctx context.Context) (any, error) {
	return w.task.UntypedRun(ctx)
}

// UntypedID returns the wrapped task's own untyped ID.
func (w *wrapGraphFirstTask) UntypedID() taskid.UntypedTaskImplementationID {
	wrapped := w.task
	return wrapped.UntypedID()
}

// Compile-time check that wrapGraphFirstTask satisfies Task[any].
var _ Task[any] = (*wrapGraphFirstTask)(nil)