pkg/task/runner.go (179 lines of code) (raw):

// Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package task import ( "context" "fmt" "log/slog" "sync" "time" "github.com/GoogleCloudPlatform/khi/pkg/common/khictx" "github.com/GoogleCloudPlatform/khi/pkg/common/typedmap" task_contextkey "github.com/GoogleCloudPlatform/khi/pkg/task/contextkey" task_interface "github.com/GoogleCloudPlatform/khi/pkg/task/inteface" "github.com/GoogleCloudPlatform/khi/pkg/task/taskid" "golang.org/x/sync/errgroup" ) var _ task_interface.TaskRunner = (*LocalRunner)(nil) type LocalRunner struct { resolvedTaskSet *TaskSet resultVariable *typedmap.TypedMap resultError error started bool stopped bool taskWaiters *sync.Map // sync.Map[string(taskRefID), sync.RWMutex], runner acquire the write lock at the beginning. All dependents will acquire read lock, it will be released when the task run finished. waiter chan interface{} taskStatuses []*LocalRunnerTaskStat } type LocalRunnerTaskStat struct { Phase string Error error StartTime time.Time EndTime time.Time } const ( LocalRunnerTaskStatPhaseWaiting = "WAITING" LocalRunnerTaskStatPhaseRunning = "RUNNING" LocalRunnerTaskStatPhaseStopped = "STOPPED" ) func (r *LocalRunner) Wait() <-chan interface{} { return r.waiter } // Result implements Runner. func (r *LocalRunner) Result() (*typedmap.ReadonlyTypedMap, error) { if !r.stopped { return nil, fmt.Errorf("this task runner hasn't finished yet") } if r.resultError != nil { return nil, r.resultError } return r.resultVariable.AsReadonly(), nil } // Run implements Runner. func (r *LocalRunner) Run(ctx context.Context) error { if r.started { return fmt.Errorf("this task is already started before") } go func() { defer r.markDone() // Setting up graph context r.resultVariable = typedmap.NewTypedMap() ctx = khictx.WithValue(ctx, task_contextkey.TaskResultMapContextKey, r.resultVariable) tasks := r.resolvedTaskSet.GetAll() cancelableCtx, cancel := context.WithCancel(ctx) currentErrGrp, currentErrCtx := errgroup.WithContext(cancelableCtx) for i := range tasks { taskDefIndex := i currentErrGrp.Go(func() error { err := r.runTask(currentErrCtx, taskDefIndex) if err != nil { cancel() return err } return nil }) } err := currentErrGrp.Wait() if err != nil { r.resultError = err } cancel() }() return nil } func (r *LocalRunner) runTask(graphCtx context.Context, taskDefIndex int) error { task := r.resolvedTaskSet.GetAll()[taskDefIndex] sources := task.Dependencies() taskStatus := r.taskStatuses[taskDefIndex] taskCtx := khictx.WithValue(graphCtx, task_contextkey.TaskImplementationIDContextKey, task.UntypedID()) slog.DebugContext(taskCtx, fmt.Sprintf("task %s started", task.UntypedID().String())) r.waitDependencies(taskCtx, sources) if taskCtx.Err() == context.Canceled { return context.Canceled } taskStatus.StartTime = time.Now() taskStatus.Phase = LocalRunnerTaskStatPhaseRunning slog.DebugContext(taskCtx, fmt.Sprintf("task %s started", task.UntypedID())) result, err := task.UntypedRun(taskCtx) taskStatus.Phase = LocalRunnerTaskStatPhaseStopped taskStatus.EndTime = time.Now() slog.DebugContext(taskCtx, fmt.Sprintf("task %s stopped after %f sec", task.UntypedID(), taskStatus.EndTime.Sub(taskStatus.StartTime).Seconds())) taskStatus.Error = err if taskCtx.Err() == context.Canceled { return context.Canceled } if err != nil { detailedErr := r.wrapWithTaskError(err, task) r.resultError = detailedErr slog.ErrorContext(taskCtx, err.Error()) return detailedErr } typedmap.Set(r.resultVariable, typedmap.NewTypedKey[any](task.UntypedID().GetUntypedReference().ReferenceIDString()), result) taskWaiter, _ := r.taskWaiters.Load(task.UntypedID().GetUntypedReference().String()) taskWaiter.(*sync.RWMutex).Unlock() return nil } func (r *LocalRunner) TaskStatuses() []*LocalRunnerTaskStat { return r.taskStatuses } func newLocalRunnerTaskStatus() *LocalRunnerTaskStat { return &LocalRunnerTaskStat{ Phase: LocalRunnerTaskStatPhaseWaiting, } } func NewLocalRunner(taskSet *TaskSet) (*LocalRunner, error) { if !taskSet.runnable { return nil, fmt.Errorf("given taskset must be runnable") } taskStatuses := []*LocalRunnerTaskStat{} taskWaiters := sync.Map{} for i := 0; i < len(taskSet.tasks); i++ { taskStatuses = append(taskStatuses, newLocalRunnerTaskStatus()) // lock the task waiter until its task finished. waiter := sync.RWMutex{} waiter.Lock() taskWaiters.Store(taskSet.tasks[i].UntypedID().ReferenceIDString(), &waiter) } return &LocalRunner{ resolvedTaskSet: taskSet, started: false, resultVariable: nil, resultError: nil, stopped: false, taskWaiters: &taskWaiters, waiter: make(chan interface{}), taskStatuses: taskStatuses, }, nil } func (r *LocalRunner) markDone() { r.stopped = true close(r.waiter) r.taskWaiters.Range(func(key, value any) bool { mutex, _ := value.(*sync.RWMutex) if !mutex.TryRLock() { mutex.Unlock() } return true }) } func (r *LocalRunner) waitDependencies(ctx context.Context, dependencies []taskid.UntypedTaskReference) error { for _, dependency := range dependencies { select { // wait for getting the RLock for the task result, or context cancel case <-ctx.Done(): return ctx.Err() case <-func() chan struct{} { ch := make(chan struct{}) go func() { waiter, _ := r.taskWaiters.Load(dependency.ReferenceIDString()) taskWaiter := waiter.(*sync.RWMutex) taskWaiter.RLock() close(ch) }() return ch }(): continue } } return nil } func (r *LocalRunner) wrapWithTaskError(err error, task UntypedTask) error { errMsg := fmt.Sprintf("failed to run a task graph.\n task ID=%s got an error. \n ERROR:\n%v", task.UntypedID(), err) return fmt.Errorf("%s", errMsg) } // GetTaskResultFromLocalRunner returns task results from the local runner task results. func GetTaskResultFromLocalRunner[TaskResult any](runner *LocalRunner, taskRef taskid.TaskReference[TaskResult]) (TaskResult, bool) { return typedmap.Get(runner.resultVariable, typedmap.NewTypedKey[TaskResult](taskRef.String())) }