pkg/source/gcp/query/query.go (151 lines of code) (raw):
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package query
import (
"context"
"fmt"
"log/slog"
"slices"
"strings"
"github.com/GoogleCloudPlatform/khi/pkg/common/khictx"
"github.com/GoogleCloudPlatform/khi/pkg/common/typedmap"
"github.com/GoogleCloudPlatform/khi/pkg/common/worker"
inspection_task_contextkey "github.com/GoogleCloudPlatform/khi/pkg/inspection/contextkey"
inspection_task_interface "github.com/GoogleCloudPlatform/khi/pkg/inspection/interface"
error_metadata "github.com/GoogleCloudPlatform/khi/pkg/inspection/metadata/error"
"github.com/GoogleCloudPlatform/khi/pkg/inspection/metadata/progress"
"github.com/GoogleCloudPlatform/khi/pkg/inspection/metadata/query"
inspection_task "github.com/GoogleCloudPlatform/khi/pkg/inspection/task"
"github.com/GoogleCloudPlatform/khi/pkg/inspection/task/label"
"github.com/GoogleCloudPlatform/khi/pkg/log"
"github.com/GoogleCloudPlatform/khi/pkg/model/enum"
"github.com/GoogleCloudPlatform/khi/pkg/source/gcp/api"
"github.com/GoogleCloudPlatform/khi/pkg/source/gcp/query/queryutil"
gcp_task "github.com/GoogleCloudPlatform/khi/pkg/source/gcp/task"
gcp_taskid "github.com/GoogleCloudPlatform/khi/pkg/source/gcp/taskid"
"github.com/GoogleCloudPlatform/khi/pkg/task"
"github.com/GoogleCloudPlatform/khi/pkg/task/taskid"
)
// GKEQueryPrefix is gcp_task.GCPPrefix followed by "query/gke/".
// NOTE(review): presumably the common ID prefix for GKE log query tasks — confirm against its users.
const GKEQueryPrefix = gcp_task.GCPPrefix + "query/gke/"
// SkipQueryBody is the marker string "@Skip". A query task returns this value
// when the query builder decided to skip the query.
// NOTE(review): NewQueryGeneratorTask, as visible in this file, signals a skip
// through an empty query list and does not reference this constant — confirm
// where it is still consumed.
const SkipQueryBody = "@Skip"
// QueryGeneratorFunc builds the Cloud Logging filter strings for a query task
// from the task context and the current task mode.
// NewQueryGeneratorTask handles every returned string as a separate query and
// treats an empty list as a decision to skip the query. A non-nil error makes
// the task fail with that error.
type QueryGeneratorFunc = func(context.Context, inspection_task_interface.InspectionTaskMode) ([]string, error)
// DefaultResourceNamesGenerator provides the default resource names used for querying Cloud Logging.
// NewQueryGeneratorTask uses these names unless the task input supplies resource names for the query.
type DefaultResourceNamesGenerator interface {
// GetDependentTasks returns the list of task references needed for generating resource names.
// NewQueryGeneratorTask adds them to the dependencies of the task it creates.
GetDependentTasks() []taskid.UntypedTaskReference
// GenerateResourceNames returns the list of default resource names.
// A non-nil error makes the query task fail with that error.
GenerateResourceNames(ctx context.Context) ([]string, error)
}
// ProjectIDDefaultResourceNamesGenerator is a DefaultResourceNamesGenerator that
// yields a single resource name of the form "projects/<project ID>", where the
// project ID is the result of the gcp_task.InputProjectIdTaskID task.
type ProjectIDDefaultResourceNamesGenerator struct{}
// GenerateResourceNames implements DefaultResourceNamesGenerator.
// It reads the result of the project ID input task from ctx and returns exactly
// one resource name, "projects/<project ID>". The returned error is always nil.
func (p *ProjectIDDefaultResourceNamesGenerator) GenerateResourceNames(ctx context.Context) ([]string, error) {
	projectResourceName := fmt.Sprintf(
		"projects/%s",
		task.GetTaskResult(ctx, gcp_task.InputProjectIdTaskID.Ref()),
	)
	return []string{projectResourceName}, nil
}
// GetDependentTasks implements DefaultResourceNamesGenerator.
// The only dependency is the project ID input task, whose result
// GenerateResourceNames reads.
func (p *ProjectIDDefaultResourceNamesGenerator) GetDependentTasks() []taskid.UntypedTaskReference {
	projectIDTaskRef := gcp_task.InputProjectIdTaskID.Ref()
	return []taskid.UntypedTaskReference{projectIDTaskRef}
}
// Compile-time check that *ProjectIDDefaultResourceNamesGenerator implements DefaultResourceNamesGenerator.
var _ DefaultResourceNamesGenerator = (*ProjectIDDefaultResourceNamesGenerator)(nil)
// queryThreadPool is the package-level worker pool handed to every parallel
// query worker created by NewQueryGeneratorTask, so all query tasks share it.
// NOTE(review): 16 is presumably the maximum number of concurrent workers — confirm against worker.NewPool.
var queryThreadPool = worker.NewPool(16)
// NewQueryGeneratorTask creates a progress-reportable inspection task that
// builds Cloud Logging queries with generator and, in run mode, executes them
// and returns the collected logs sorted by timestamp.
//
// Parameters:
//   - taskId: ID of the created task. It is also the key under which queries and
//     resource names are recorded.
//   - readableQueryName: name stored with each query in the query metadata. When
//     generator returns more than one query, "-<index>" is appended per query.
//   - logType: assigned to every returned log and passed to the query task label.
//   - dependencies: task references required by generator. The slice is copied
//     and never modified.
//   - resourceNamesGenerator: supplies the default resource names and the task
//     references it needs to compute them.
//   - generator: returns the filter strings; an empty list skips the query.
//   - sampleQuery: passed to the query task label.
//
// The task body behaves as follows:
//   - Resource names come from the task input when an entry exists for this
//     query, otherwise from resourceNamesGenerator.
//   - Every query is suffixed with the time range section and recorded in the
//     query metadata regardless of the task mode.
//   - Queries are executed only when the task mode is TaskModeRun. In any other
//     mode the task returns an empty list.
//   - A query failure whose message starts with "401:", "403:" or "404:" is
//     additionally registered in the error message set metadata before the error
//     is returned.
func NewQueryGeneratorTask(taskId taskid.TaskImplementationID[[]*log.LogEntity], readableQueryName string, logType enum.LogType, dependencies []taskid.UntypedTaskReference, resourceNamesGenerator DefaultResourceNamesGenerator, generator QueryGeneratorFunc, sampleQuery string) task.Task[[]*log.LogEntity] {
	generatorDependencies := resourceNamesGenerator.GetDependentTasks()
	// Build the dependency list in a fresh slice. Appending to `dependencies`
	// directly could write into spare capacity of the caller's backing array and
	// silently corrupt other slices sharing it.
	allDependencies := make([]taskid.UntypedTaskReference, 0, len(dependencies)+len(generatorDependencies)+4)
	allDependencies = append(allDependencies, dependencies...)
	allDependencies = append(allDependencies, generatorDependencies...)
	allDependencies = append(allDependencies,
		gcp_task.InputStartTimeTaskID.Ref(),
		gcp_task.InputEndTimeTaskID.Ref(),
		inspection_task.ReaderFactoryGeneratorTaskID.Ref(),
		gcp_taskid.LoggingFilterResourceNameInputTaskID.Ref(),
	)

	return inspection_task.NewProgressReportableInspectionTask(taskId, allDependencies, func(ctx context.Context, taskMode inspection_task_interface.InspectionTaskMode, taskProgress *progress.TaskProgress) ([]*log.LogEntity, error) {
		// NOTE(review): the client is created in every task mode although it is
		// only used when the query actually runs.
		client, err := api.DefaultGCPClientFactory.NewClient()
		if err != nil {
			return nil, err
		}

		metadata := khictx.MustGetValue(ctx, inspection_task_contextkey.InspectionRunMetadata)
		resourceNames := task.GetTaskResult(ctx, gcp_taskid.LoggingFilterResourceNameInputTaskID.Ref())
		taskInput := khictx.MustGetValue(ctx, inspection_task_contextkey.InspectionTaskInput)

		defaultResourceNames, err := resourceNamesGenerator.GenerateResourceNames(ctx)
		if err != nil {
			return nil, err
		}
		resourceNames.UpdateDefaultResourceNamesForQuery(taskId.ReferenceIDString(), defaultResourceNames)
		queryResourceNamePair := resourceNames.GetResourceNamesForQuery(taskId.ReferenceIDString())

		// Use the defaults unless the task input carries resource names for this query.
		resourceNamesFromInput := defaultResourceNames
		if rawInput, found := taskInput[queryResourceNamePair.GetInputID()]; found {
			// Report a non-string input as an error instead of panicking on the assertion.
			inputStr, isString := rawInput.(string)
			if !isString {
				return nil, fmt.Errorf("resource names input for query %q must be a string, got %T", taskId.ReferenceIDString(), rawInput)
			}
			splitNames := strings.Split(inputStr, " ")
			// NOTE(review): when any entry fails validation, the raw split
			// entries are used as-is. Presumably the invalid input is reported
			// to the user elsewhere — confirm.
			resourceNamesFromInput = splitNames
			validatedNames := make([]string, 0, len(splitNames))
			allValid := true
			for _, splitName := range splitNames {
				trimmedName := strings.TrimSpace(splitName)
				if validationErr := api.ValidateResourceNameOnLogEntriesList(trimmedName); validationErr != nil {
					allValid = false
					break
				}
				validatedNames = append(validatedNames, trimmedName)
			}
			if allValid {
				resourceNamesFromInput = validatedNames
			}
		}

		readerFactory := task.GetTaskResult(ctx, inspection_task.ReaderFactoryGeneratorTaskID.Ref())
		startTime := task.GetTaskResult(ctx, gcp_task.InputStartTimeTaskID.Ref())
		endTime := task.GetTaskResult(ctx, gcp_task.InputEndTimeTaskID.Ref())

		queryStrings, err := generator(ctx, taskMode)
		if err != nil {
			return nil, err
		}
		if len(queryStrings) == 0 {
			slog.InfoContext(ctx, fmt.Sprintf("Query generator `%s` decided to skip.", taskId))
			return []*log.LogEntity{}, nil
		}

		queryInfo, found := typedmap.Get(metadata, query.QueryMetadataKey)
		if !found {
			return nil, fmt.Errorf("query metadata was not found")
		}

		allLogs := []*log.LogEntity{}
		for queryIndex, queryString := range queryStrings {
			// Record query information in metadata.
			readableQueryNameForQueryIndex := readableQueryName
			if len(queryStrings) > 1 {
				readableQueryNameForQueryIndex = fmt.Sprintf("%s-%d", readableQueryName, queryIndex)
			}
			finalQuery := fmt.Sprintf("%s\n%s", queryString, queryutil.TimeRangeQuerySection(startTime, endTime, true))
			if len(finalQuery) > 20000 {
				slog.WarnContext(ctx, fmt.Sprintf("Logging filter is exceeding Cloud Logging limitation 20000 charactors\n%s", finalQuery))
			}
			queryInfo.SetQuery(taskId.String(), readableQueryNameForQueryIndex, finalQuery)

			// Run the query only when the task mode is for running.
			if taskMode != inspection_task_interface.TaskModeRun {
				continue
			}

			// TODO: not to store whole logs on memory to avoid OOM
			queryWorker := queryutil.NewParallelQueryWorker(queryThreadPool, client, queryString, startTime, endTime, 5)
			queryLogs, queryErr := queryWorker.Query(ctx, readerFactory, resourceNamesFromInput, taskProgress)
			if queryErr != nil {
				errorMessageSet, found := typedmap.Get(metadata, error_metadata.ErrorMessageSetMetadataKey)
				if !found {
					return nil, fmt.Errorf("error message set metadata was not found")
				}
				queryErrMessage := queryErr.Error()
				// TODO: these errors are shown to frontend but it's not well implemented.
				switch {
				case strings.HasPrefix(queryErrMessage, "401:"):
					errorMessageSet.AddErrorMessage(error_metadata.NewUnauthorizedErrorMessage())
				case strings.HasPrefix(queryErrMessage, "403:"),
					strings.HasPrefix(queryErrMessage, "404:"):
					errorMessageSet.AddErrorMessage(&error_metadata.ErrorMessage{
						ErrorId: 0,
						Message: queryErrMessage,
					})
				}
				return nil, queryErr
			}
			allLogs = append(allLogs, queryLogs...)
		}

		if taskMode != inspection_task_interface.TaskModeRun {
			return []*log.LogEntity{}, nil
		}

		// Compare timestamps directly. Converting the time difference to int
		// truncates it on platforms where int is 32 bits wide, which can flip
		// the sign and break the ordering.
		slices.SortFunc(allLogs, func(a, b *log.LogEntity) int {
			return a.Timestamp().Compare(b.Timestamp())
		})
		for _, l := range allLogs {
			l.LogType = logType
		}
		return allLogs, nil
	}, label.NewQueryTaskLabelOpt(logType, sampleQuery))
}