pkg/inspection/runner.go (383 lines of code) (raw):
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package inspection
import (
"context"
"errors"
"fmt"
"log/slog"
"math/rand"
"sync"
"time"
"github.com/GoogleCloudPlatform/khi/pkg/common/filter"
"github.com/GoogleCloudPlatform/khi/pkg/common/khictx"
"github.com/GoogleCloudPlatform/khi/pkg/common/typedmap"
inspection_task_contextkey "github.com/GoogleCloudPlatform/khi/pkg/inspection/contextkey"
"github.com/GoogleCloudPlatform/khi/pkg/inspection/inspectiondata"
inspection_task_interface "github.com/GoogleCloudPlatform/khi/pkg/inspection/interface"
"github.com/GoogleCloudPlatform/khi/pkg/inspection/metadata"
error_metadata "github.com/GoogleCloudPlatform/khi/pkg/inspection/metadata/error"
"github.com/GoogleCloudPlatform/khi/pkg/inspection/metadata/form"
"github.com/GoogleCloudPlatform/khi/pkg/inspection/metadata/header"
"github.com/GoogleCloudPlatform/khi/pkg/inspection/metadata/logger"
"github.com/GoogleCloudPlatform/khi/pkg/inspection/metadata/plan"
"github.com/GoogleCloudPlatform/khi/pkg/inspection/metadata/progress"
"github.com/GoogleCloudPlatform/khi/pkg/inspection/metadata/query"
inspection_task "github.com/GoogleCloudPlatform/khi/pkg/inspection/task"
"github.com/GoogleCloudPlatform/khi/pkg/inspection/task/serializer"
"github.com/GoogleCloudPlatform/khi/pkg/lifecycle"
"github.com/GoogleCloudPlatform/khi/pkg/parameters"
"github.com/GoogleCloudPlatform/khi/pkg/task"
task_contextkey "github.com/GoogleCloudPlatform/khi/pkg/task/contextkey"
task_interface "github.com/GoogleCloudPlatform/khi/pkg/task/inteface"
"github.com/GoogleCloudPlatform/khi/pkg/task/taskid"
)
// inspectionRunnerGlobalSharedMap is a package-level typed map that every
// InspectionTaskRunner places into its run context under the GlobalSharedMap
// context key, so the same instance is visible to all runs in this process.
var inspectionRunnerGlobalSharedMap = typedmap.NewTypedMap()
// InspectionTaskRunner holds the configuration (inspection type and enabled
// features) and the execution state (runner, metadata, cancel function) of a
// single inspection.
type InspectionTaskRunner struct {
// inspectionServer provides the registered inspection types and the root task set.
inspectionServer *InspectionTaskServer
// ID is the random identifier generated for this inspection at construction time.
ID string
// enabledFeatures maps a feature task ID to whether the feature is enabled.
enabledFeatures map[string]bool
// availableTasks is the subset of the server's root task set usable with the
// current inspection type. It is nil until SetInspectionType succeeds in selecting a type.
availableTasks *task.TaskSet
// featureTasks is the set of feature tasks currently enabled.
featureTasks *task.TaskSet
// requiredTasks is the subset of availableTasks flagged as required.
requiredTasks *task.TaskSet
// runner is the task runner created by Run. It is nil until Run has been called.
runner task_interface.TaskRunner
// runnerLock serializes calls to Run.
runnerLock sync.Mutex
// metadata is the metadata set generated for the run started by Run.
metadata *typedmap.ReadonlyTypedMap
// cancel cancels the context handed to the runner by Run.
cancel context.CancelFunc
// inspectionSharedMap is placed into the context of every run and dry run of this inspection.
inspectionSharedMap *typedmap.TypedMap
// currentInspectionType is the ID of the selected inspection type ("N/A" until one is set).
currentInspectionType string
}
// NewInspectionRunner returns a runner bound to the given server.
// The runner gets a fresh random ID, an empty feature map and its own shared
// map; the inspection type is "N/A" and every task set, the task runner, the
// metadata and the cancel function are left at their zero values until
// SetInspectionType and Run populate them.
func NewInspectionRunner(server *InspectionTaskServer) *InspectionTaskRunner {
	r := new(InspectionTaskRunner)
	r.inspectionServer = server
	r.ID = generateRandomString()
	r.enabledFeatures = make(map[string]bool)
	r.inspectionSharedMap = typedmap.NewTypedMap()
	r.currentInspectionType = "N/A"
	return r
}
// Started reports whether Run has already created a task runner for this inspection.
func (i *InspectionTaskRunner) Started() bool {
	hasRunner := i.runner != nil
	return hasRunner
}
// SetInspectionType selects the inspection type with the given ID.
// It returns an error when the server has no inspection type with that ID.
// Otherwise it recomputes the available and required task sets for the type,
// records the type ID and resets the enabled features to the tasks flagged as
// default features. Any error from SetFeatureList is returned as is.
func (i *InspectionTaskRunner) SetInspectionType(inspectionType string) error {
	known := false
	for _, candidate := range i.inspectionServer.inspectionTypes {
		if candidate.Id != inspectionType {
			continue
		}
		known = true
		break
	}
	if !known {
		return fmt.Errorf("inspection type %s was not found", inspectionType)
	}

	typeFilter := filter.NewContainsElementFilter(inspection_task.LabelKeyInspectionTypes, inspectionType, true)
	i.availableTasks = task.Subset(i.inspectionServer.RootTaskSet, typeFilter)

	defaultFeatureTasks := task.Subset(i.availableTasks, filter.NewEnabledFilter(inspection_task.LabelKeyInspectionDefaultFeatureFlag, false))
	i.requiredTasks = task.Subset(i.availableTasks, filter.NewEnabledFilter(inspection_task.LabelKeyInspectionRequiredFlag, false))

	defaults := defaultFeatureTasks.GetAll()
	ids := make([]string, 0, len(defaults))
	for _, t := range defaults {
		ids = append(ids, t.UntypedID().String())
	}

	i.currentInspectionType = inspectionType
	return i.SetFeatureList(ids)
}
// FeatureList lists every feature task available for the current inspection
// type together with its title, description and enabled state.
// It returns an error when no inspection type has been set yet.
func (i *InspectionTaskRunner) FeatureList() ([]FeatureListItem, error) {
	if i.availableTasks == nil {
		return nil, fmt.Errorf("inspection type is not yet initialized")
	}

	candidates := task.Subset(i.availableTasks, filter.NewEnabledFilter(inspection_task.LabelKeyInspectionFeatureFlag, false)).GetAll()
	items := make([]FeatureListItem, 0, len(candidates))
	for _, candidate := range candidates {
		id := candidate.UntypedID().String()
		// A feature task without a title label falls back to a placeholder containing its ID.
		fallbackTitle := fmt.Sprintf("No label Set!(%s)", candidate.UntypedID())
		item := FeatureListItem{
			Id:          id,
			Label:       typedmap.GetOrDefault(candidate.Labels(), inspection_task.LabelKeyFeatureTaskTitle, fallbackTitle),
			Description: typedmap.GetOrDefault(candidate.Labels(), inspection_task.LabelKeyFeatureTaskDescription, ""),
			// A missing map entry yields false, which means "not enabled".
			Enabled: i.enabledFeatures[id],
		}
		items = append(items, item)
	}
	return items, nil
}
// SetFeatureList replaces the set of enabled features with the given task IDs.
//
// Every ID must resolve to a task in the available task set and that task must
// be labelled as a feature; otherwise an error is returned and the previously
// enabled features are left untouched. An error is also returned when no
// inspection type has been set yet, which matches the check in FeatureList.
func (i *InspectionTaskRunner) SetFeatureList(featureList []string) error {
	if i.availableTasks == nil {
		return fmt.Errorf("inspection type is not yet initialized")
	}

	featureTasks := make([]task.UntypedTask, 0, len(featureList))
	for _, featureId := range featureList {
		featureTask, err := i.availableTasks.Get(featureId)
		if err != nil {
			return err
		}
		if !typedmap.GetOrDefault(featureTask.Labels(), inspection_task.LabelKeyInspectionFeatureFlag, false) {
			return fmt.Errorf("task `%s` is not marked as a feature but requested to be included in the feature set of an inspection", featureTask.UntypedID())
		}
		featureTasks = append(featureTasks, featureTask)
	}

	featureTaskSet, err := task.NewTaskSet(featureTasks)
	if err != nil {
		return err
	}

	// State is only replaced once every requested feature has been validated.
	enabled := make(map[string]bool, len(featureList))
	for _, featureId := range featureList {
		enabled[featureId] = true
	}
	i.enabledFeatures = enabled
	i.featureTasks = featureTaskSet
	return nil
}
// UpdateFeatureMap applies a partial update to the enabled features.
//
// featureMap maps a feature task ID to the desired state: true enables the
// feature (its task is added to the feature task set), false disables it (its
// task is removed). Features not mentioned in featureMap keep their state.
//
// All IDs are validated before anything is modified, so an unknown ID or a
// task that is not labelled as a feature results in an error without a
// partially applied update. An error is also returned when no inspection type
// has been set yet.
func (i *InspectionTaskRunner) UpdateFeatureMap(featureMap map[string]bool) error {
	if i.availableTasks == nil || i.featureTasks == nil {
		return fmt.Errorf("inspection type is not yet initialized")
	}

	// First pass: resolve and validate every requested feature.
	resolved := make(map[string]task.UntypedTask, len(featureMap))
	for featureId := range featureMap {
		featureTask, err := i.availableTasks.Get(featureId)
		if err != nil {
			return err
		}
		if !typedmap.GetOrDefault(featureTask.Labels(), inspection_task.LabelKeyInspectionFeatureFlag, false) {
			return fmt.Errorf("task `%s` is not marked as a feature but requested to be included in the feature set of an inspection", featureTask.UntypedID())
		}
		resolved[featureId] = featureTask
	}

	// Second pass: apply the validated changes.
	for featureId, enabled := range featureMap {
		if enabled {
			// NOTE(review): the result of Add is discarded as before; presumably it
			// only reports that the task is already in the set - confirm.
			i.featureTasks.Add(resolved[featureId])
		} else {
			i.featureTasks.Remove(featureId)
		}
		i.enabledFeatures[featureId] = enabled
	}
	return nil
}
// withRunContextValues derives a context carrying the values scoped to one
// execution: a newly generated run ID, this inspection's ID, the per-inspection
// and global shared maps, the task input and the run mode.
func (i *InspectionTaskRunner) withRunContextValues(ctx context.Context, runMode inspection_task_interface.InspectionTaskMode, taskInput map[string]any) context.Context {
	runID := generateRandomString()

	derived := khictx.WithValue(ctx, inspection_task_contextkey.InspectionTaskRunID, runID)
	derived = khictx.WithValue(derived, inspection_task_contextkey.InspectionTaskInspectionID, i.ID)
	derived = khictx.WithValue(derived, inspection_task_contextkey.InspectionSharedMap, i.inspectionSharedMap)
	derived = khictx.WithValue(derived, inspection_task_contextkey.GlobalSharedMap, inspectionRunnerGlobalSharedMap)
	derived = khictx.WithValue(derived, inspection_task_contextkey.InspectionTaskInput, taskInput)
	derived = khictx.WithValue(derived, inspection_task_contextkey.InspectionTaskMode, runMode)
	return derived
}
// Run starts the inspection asynchronously with the inputs in req.
//
// It resolves the task graph from the enabled and required tasks, builds the
// run context and metadata, starts the task runner and returns as soon as the
// runner has been started. A background goroutine waits for the runner to
// finish, updates the progress metadata (done / cancel / error) and notifies
// the lifecycle hooks with the final status and the serialized result size.
//
// An error is returned when the inspection has already been started, when the
// task graph cannot be resolved, or when the runner cannot be created or
// started. Calls to Run are serialized by runnerLock.
func (i *InspectionTaskRunner) Run(ctx context.Context, req *inspection_task.InspectionRequest) error {
	i.runnerLock.Lock()
	defer i.runnerLock.Unlock()

	if i.runner != nil {
		return fmt.Errorf("this task is already started")
	}

	currentInspectionType := i.inspectionServer.GetInspectionType(i.currentInspectionType)
	runnableTaskGraph, err := i.resolveTaskGraph()
	if err != nil {
		return err
	}

	// Create the runner before any cancel function is stored on the receiver so
	// that a failure here leaves no state or cancellable context behind.
	runner, err := task.NewLocalRunner(runnableTaskGraph)
	if err != nil {
		return err
	}

	runCtx := i.withRunContextValues(ctx, inspection_task_interface.TaskModeRun, req.Values)
	runMetadata := i.generateMetadataForRun(runCtx, &header.Header{
		InspectTimeUnixSeconds: time.Now().Unix(),
		InspectionType:         currentInspectionType.Name,
		InspectionTypeIconPath: currentInspectionType.Icon,
		SuggestedFileName:      "unnamed.khi",
	}, runnableTaskGraph)
	runCtx = khictx.WithValue(runCtx, inspection_task_contextkey.InspectionRunMetadata, runMetadata)

	cancelableCtx, cancel := context.WithCancel(runCtx)
	i.cancel = cancel
	i.runner = runner
	i.metadata = runMetadata

	runID := khictx.MustGetValue(runCtx, inspection_task_contextkey.InspectionTaskRunID)
	inspectionTypeName := currentInspectionType.Name
	lifecycle.Default.NotifyInspectionStart(runID, inspectionTypeName)

	if err := runner.Run(cancelableCtx); err != nil {
		// Nothing is running, so release the cancellable context right away.
		cancel()
		return err
	}

	// The goroutine only uses values captured from this call rather than
	// re-reading the receiver's fields.
	go func() {
		<-runner.Wait()

		progressMeta, progressFound := typedmap.Get(runMetadata, progress.ProgressMetadataKey)
		if !progressFound {
			slog.ErrorContext(runCtx, "progress metadata was not found")
		}

		status := ""
		resultSize := 0
		result, err := runner.Result()
		if err != nil {
			// A cancelled context means the failure was requested via Cancel.
			if errors.Is(cancelableCtx.Err(), context.Canceled) {
				if progressFound {
					progressMeta.Cancel()
				}
				status = "cancel"
			} else {
				if progressFound {
					progressMeta.Error()
				}
				status = "error"
			}
			slog.WarnContext(runCtx, fmt.Sprintf("task %s was finished with an error\n%s", i.ID, err))
		} else {
			if progressFound {
				progressMeta.Done()
			}
			status = "done"
			history, historyFound := typedmap.Get(result, typedmap.NewTypedKey[inspectiondata.Store](serializer.SerializerTaskID.ReferenceIDString()))
			switch {
			case !historyFound:
				// There is no error value to report here; the runner succeeded.
				slog.ErrorContext(runCtx, "Failed to get generated history after the completion")
			case history == nil:
				slog.ErrorContext(runCtx, "Failed to get the serializer result. Result is nil!")
			default:
				resultSize, err = history.GetInspectionResultSizeInBytes()
				if err != nil {
					slog.ErrorContext(runCtx, fmt.Sprintf("Failed to get the serialized result size\n%s", err))
				}
			}
		}
		lifecycle.Default.NotifyInspectionEnd(runID, inspectionTypeName, status, resultSize)
	}()
	return nil
}
// Result returns the outcome of the run: the serializer's data store and the
// metadata flagged for inclusion in run results.
// It returns an error when the inspection has not been started, when the
// runner reports an error, when the serializer result is missing from the
// runner's result map, or when the metadata cannot be converted.
func (i *InspectionTaskRunner) Result() (*InspectionRunResult, error) {
	if i.runner == nil {
		return nil, fmt.Errorf("this task is not yet started")
	}

	taskResults, err := i.runner.Result()
	if err != nil {
		return nil, err
	}

	storeKey := typedmap.NewTypedKey[inspectiondata.Store](serializer.SerializerTaskID.ReferenceIDString())
	store, ok := typedmap.Get(taskResults, storeKey)
	if !ok {
		return nil, fmt.Errorf("failed to get the serializer result")
	}

	includedInResult := filter.NewEnabledFilter(metadata.LabelKeyIncludedInRunResultFlag, false)
	serialized, err := metadata.GetSerializableSubsetMapFromMetadataSet(i.metadata, includedInResult)
	if err != nil {
		return nil, err
	}

	runResult := &InspectionRunResult{
		Metadata:    serialized,
		ResultStore: store,
	}
	return runResult, nil
}
// Metadata returns the serializable form of the run metadata flagged for
// inclusion in run results. It returns an error when the inspection has not
// been started or when the metadata cannot be converted.
func (i *InspectionTaskRunner) Metadata() (map[string]any, error) {
	if i.runner == nil {
		return nil, fmt.Errorf("this task is not yet started")
	}

	includedInResult := filter.NewEnabledFilter(metadata.LabelKeyIncludedInRunResultFlag, false)
	serialized, err := metadata.GetSerializableSubsetMapFromMetadataSet(i.metadata, includedInResult)
	if err == nil {
		return serialized, nil
	}
	return nil, err
}
// DryRun executes the resolved task graph synchronously in dry-run mode with
// the inputs in req and returns the metadata flagged for inclusion in dry-run
// results. It blocks until the runner has finished and does not touch the
// runner, metadata or cancel function stored on the receiver.
func (i *InspectionTaskRunner) DryRun(ctx context.Context, req *inspection_task.InspectionRequest) (*InspectionDryRunResult, error) {
	slog.DebugContext(ctx, "starting resolving task graph")
	graph, err := i.resolveTaskGraph()
	if err != nil {
		slog.ErrorContext(ctx, err.Error())
		return nil, err
	}
	slog.DebugContext(ctx, "end resolving task graph")

	dryRunner, err := task.NewLocalRunner(graph)
	if err != nil {
		return nil, err
	}

	dryRunCtx := i.withRunContextValues(ctx, inspection_task_interface.TaskModeDryRun, req.Values)
	dryRunMeta := i.generateMetadataForDryRun(dryRunCtx, &header.Header{}, graph)
	dryRunCtx = khictx.WithValue(dryRunCtx, inspection_task_contextkey.InspectionRunMetadata, dryRunMeta)

	if err = dryRunner.Run(dryRunCtx); err != nil {
		return nil, err
	}
	<-dryRunner.Wait()

	// Only the error matters here; the task results themselves are discarded.
	if _, err = dryRunner.Result(); err != nil {
		slog.ErrorContext(dryRunCtx, err.Error())
		return nil, err
	}

	includedInDryRun := filter.NewEnabledFilter(metadata.LabelKeyIncludedInDryRunResultFlag, false)
	serialized, err := metadata.GetSerializableSubsetMapFromMetadataSet(dryRunMeta, includedInDryRun)
	if err != nil {
		return nil, err
	}
	return &InspectionDryRunResult{Metadata: serialized}, nil
}
// MakeLoggers creates a logger set and registers one task logger with the
// given minimum level for each task in tasks. Each task logger is created with
// a context carrying that task's implementation ID. The parameter m is not
// used by this method.
func (i *InspectionTaskRunner) MakeLoggers(ctx context.Context, minLevel slog.Level, m *typedmap.ReadonlyTypedMap, tasks []task.UntypedTask) *logger.Logger {
	taskLoggers := logger.NewLogger()
	for idx := range tasks {
		perTaskCtx := khictx.WithValue(ctx, task_contextkey.TaskImplementationIDContextKey, tasks[idx].UntypedID())
		taskLoggers.MakeTaskLogger(perTaskCtx, minLevel)
	}
	return taskLoggers
}
// GetCurrentMetadata returns the metadata set of the run, or an error when no
// metadata has been stored yet because Run has not been called successfully.
func (i *InspectionTaskRunner) GetCurrentMetadata() (*typedmap.ReadonlyTypedMap, error) {
	if current := i.metadata; current != nil {
		return current, nil
	}
	return nil, fmt.Errorf("this task hasn't been started")
}
// Cancel cancels the context given to the task runner.
// It returns an error when no cancel function has been stored yet, or when
// Result already succeeds, which means the inspection has finished.
func (i *InspectionTaskRunner) Cancel() error {
	if i.cancel == nil {
		return fmt.Errorf("this task is not yet started")
	}

	_, resultErr := i.Result()
	if resultErr == nil {
		return fmt.Errorf("task %s is already finished", i.ID)
	}

	i.cancel()
	return nil
}
// Wait returns the channel exposed by the underlying task runner's Wait method.
// It dereferences the runner without a nil check, so it must only be called
// after Run has created the runner.
func (i *InspectionTaskRunner) Wait() <-chan interface{} {
	done := i.runner.Wait()
	return done
}
// resolveTaskGraph builds the runnable task set for this inspection.
// The enabled feature tasks and the required tasks are resolved against the
// available tasks, wrapped into the inspection main subgraph, extended with
// the serializer task and resolved once more.
// It returns an error when the inspection type or the feature set has not been
// initialized, or when any of the steps above fails.
func (i *InspectionTaskRunner) resolveTaskGraph() (*task.TaskSet, error) {
	if i.availableTasks == nil || i.featureTasks == nil {
		return nil, fmt.Errorf("this runner is not ready for resolving graph")
	}

	features := i.featureTasks.GetAll()
	required := i.requiredTasks.GetAll()
	seeds := make([]task.UntypedTask, 0, len(features)+len(required))
	seeds = append(seeds, features...)
	seeds = append(seeds, required...)

	seedSet, err := task.NewTaskSet(seeds)
	if err != nil {
		return nil, err
	}

	resolved, err := seedSet.ResolveTask(i.availableTasks)
	if err != nil {
		return nil, err
	}

	subgraphID := taskid.NewDefaultImplementationID[any](inspection_task.InspectionMainSubgraphName)
	subgraph, err := resolved.WrapGraph(subgraphID, []taskid.UntypedTaskReference{})
	if err != nil {
		return nil, err
	}

	// The serializer task is added around the wrapped subgraph as a required
	// pre/post processing step.
	if err = subgraph.Add(serializer.SerializeTask); err != nil {
		return nil, err
	}
	return subgraph.ResolveTask(i.availableTasks)
}
// generateMetadataForDryRun creates the metadata set used by a dry run.
// It contains only the common metadata and is returned as a read-only map.
func (i *InspectionTaskRunner) generateMetadataForDryRun(ctx context.Context, initHeader *header.Header, taskGraph *task.TaskSet) *typedmap.ReadonlyTypedMap {
	md := typedmap.NewTypedMap()
	i.addCommonMetadata(ctx, md, initHeader, taskGraph)
	readonly := md.AsReadonly()
	return readonly
}
// generateMetadataForRun creates the metadata set used by an actual run.
// It contains only the common metadata and is returned as a read-only map.
func (i *InspectionTaskRunner) generateMetadataForRun(ctx context.Context, initHeader *header.Header, taskGraph *task.TaskSet) *typedmap.ReadonlyTypedMap {
	md := typedmap.NewTypedMap()
	i.addCommonMetadata(ctx, md, initHeader, taskGraph)
	readonly := md.AsReadonly()
	return readonly
}
// addCommonMetadata fills writableMetadata with the entries shared by runs and
// dry runs: the header, an empty error message set, an empty form field set,
// query metadata, a progress tracker sized by the number of progress-reportable
// tasks, and the inspection plan holding the Graphviz dump of the task graph.
// It finally creates the task loggers for every task in the graph.
func (i *InspectionTaskRunner) addCommonMetadata(ctx context.Context, writableMetadata *typedmap.TypedMap, initHeader *header.Header, taskGraph *task.TaskSet) {
	typedmap.Set(writableMetadata, header.HeaderMetadataKey, initHeader)
	typedmap.Set(writableMetadata, error_metadata.ErrorMessageSetMetadataKey, error_metadata.NewErrorMessageSet())
	typedmap.Set(writableMetadata, form.FormFieldSetMetadataKey, form.NewFormFieldSet())
	typedmap.Set(writableMetadata, query.QueryMetadataKey, query.NewQueryMetadata())

	tracker := progress.NewProgress()
	reportable := task.Subset(taskGraph, filter.NewEnabledFilter(inspection_task.LabelKeyProgressReportable, false))
	tracker.SetTotalTaskCount(len(reportable.GetAll()))
	typedmap.Set(writableMetadata, progress.ProgressMetadataKey, tracker)

	// When the graph cannot be dumped, the plan carries the failure message instead.
	graphviz, dumpErr := taskGraph.DumpGraphviz()
	if dumpErr != nil {
		graphviz = fmt.Sprintf("failed to generate task graph %v", dumpErr.Error())
	}
	typedmap.Set(writableMetadata, plan.InspectionPlanMetadataKey, plan.NewInspectionPlan(graphviz))

	// The logger set returned by MakeLoggers is not used here.
	i.MakeLoggers(ctx, getLogLevel(), writableMetadata.AsReadonly(), taskGraph.GetAll())
}
// generateRandomString returns a 16-character identifier made of ASCII letters
// (a-z, A-Z) chosen with math/rand. It is meant for run and inspection IDs and
// is not cryptographically secure.
func generateRandomString() string {
	const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
	const idLength = 16

	id := make([]byte, idLength)
	for pos := range id {
		id[pos] = alphabet[rand.Intn(len(alphabet))]
	}
	return string(id)
}
// getLogLevel returns slog.LevelDebug when the verbose debug parameter is set
// and true, and slog.LevelInfo otherwise.
func getLogLevel() slog.Level {
	verbose := parameters.Debug.Verbose
	if verbose == nil || !*verbose {
		return slog.LevelInfo
	}
	return slog.LevelDebug
}