pkg/source/gcp/query/queryutil/parallel.go
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package queryutil
import (
"context"
"errors"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
"github.com/GoogleCloudPlatform/khi/pkg/common/worker"
"github.com/GoogleCloudPlatform/khi/pkg/inspection/metadata/progress"
"github.com/GoogleCloudPlatform/khi/pkg/log"
"github.com/GoogleCloudPlatform/khi/pkg/log/structure"
"github.com/GoogleCloudPlatform/khi/pkg/log/structure/adapter"
"github.com/GoogleCloudPlatform/khi/pkg/parser/yaml/yamlutil"
"github.com/GoogleCloudPlatform/khi/pkg/source/gcp/api"
gcp_log "github.com/GoogleCloudPlatform/khi/pkg/source/gcp/log"
)
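
// ParallelQueryWorker executes a single Cloud Logging query in parallel by
// splitting its time range into equally sized segments and issuing one query
// per segment on a shared worker pool.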
type ParallelQueryWorker struct {
workerCount int
baseQuery string
startTime time.Time
endTime time.Time
apiClient api.GCPClient
pool *worker.Pool
}
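
// NewParallelQueryWorker returns a ParallelQueryWorker that splits the range
// between startTime and endTime into workerCount segments and queries them in
// parallel on the given worker pool.
//
// A minimal usage sketch (pool, apiClient, readerFactory and taskProgress are
// assumed to be constructed elsewhere; the filter string is illustrative):
//
//	qw := NewParallelQueryWorker(pool, apiClient, `resource.type="k8s_cluster"`, start, end, 4)
//	logs, err := qw.Query(ctx, readerFactory, []string{"projects/my-project"}, taskProgress)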
func NewParallelQueryWorker(pool *worker.Pool, apiClient api.GCPClient, baseQuery string, startTime time.Time, endTime time.Time, workerCount int) *ParallelQueryWorker {
return &ParallelQueryWorker{
baseQuery: baseQuery,
startTime: startTime,
endTime: endTime,
workerCount: workerCount,
apiClient: apiClient,
pool: pool,
}
}
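
// Query runs the base query against the given resource names across all time
// segments and returns the collected log entries. Progress and query speed
// ("lps" = logs per second) are reported through the given TaskProgress about
// once per second. The first error from any segment cancels the remaining
// segments and is returned to the caller.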
func (p *ParallelQueryWorker) Query(ctx context.Context, readerFactory *structure.ReaderFactory, resourceNames []string, progress *progress.TaskProgress) ([]*log.LogEntity, error) {
timeSegments := divideTimeSegments(p.startTime, p.endTime, p.workerCount)
percentages := make([]float32, p.workerCount)
logSink := make(chan *log.LogEntity)
logEntries := []*log.LogEntity{}
wg := sync.WaitGroup{}
queryStartTime := time.Now()
threadCount := atomic.Int32{}
threadCount.Add(1)
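	// Collect log entries from all workers into logEntries and keep the progress reporter running until collection finishes.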
	aggregatorDone := make(chan struct{})
	go func() {
		defer close(aggregatorDone)
		cancellable, cancel := context.WithCancel(ctx)
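		// Report throughput and the average of the per-segment completion ratios roughly once per second until aggregation below completes.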
go func() {
for {
select {
case <-cancellable.Done():
return
case <-time.After(time.Second):
currentTime := time.Now()
speed := float64(len(logEntries)) / currentTime.Sub(queryStartTime).Seconds()
s := float32(0)
for _, p := range percentages {
s += p
}
progressRatio := s / float32(len(percentages))
progress.Update(progressRatio, fmt.Sprintf("%.2f lps(concurrency %d)", speed, threadCount.Load()))
}
}
}()
for logEntry := range logSink {
logEntries = append(logEntries, logEntry)
}
cancel()
}()
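	// cancellableCtx propagates the first worker error as a cancellation cause so the remaining queries stop early.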
cancellableCtx, cancel := context.WithCancelCause(ctx)
defer cancel(errors.New("query completed"))
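	// Dispatch one query per time segment to the shared worker pool.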
for i := 0; i < len(timeSegments)-1; i++ {
workerIndex := i
begin := timeSegments[i]
end := timeSegments[i+1]
		// Only the final segment (loop index len(timeSegments)-2) must include the end boundary itself.
		includeEnd := i == len(timeSegments)-2
query := fmt.Sprintf("%s\n%s", p.baseQuery, TimeRangeQuerySection(begin, end, includeEnd))
subLogSink := make(chan any)
wg.Add(1)
p.pool.Run(func() {
defer wg.Done()
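			// Run the Logging API call in its own goroutine; it streams raw entries into subLogSink, which is expected to be closed by the client when the query finishes, ending the range below.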
go func() {
threadCount.Add(1)
err := p.apiClient.ListLogEntries(cancellableCtx, resourceNames, query, subLogSink)
if err != nil && !errors.Is(err, context.Canceled) {
slog.WarnContext(cancellableCtx, fmt.Sprintf("query thread failed with an error\n%s", err))
cancel(err)
}
}()
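			// Convert each raw entry to YAML, wrap it as a LogEntity, and forward it to the shared logSink.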
for logEntryAny := range subLogSink {
yamlString, err := yamlutil.MarshalToYamlString(logEntryAny)
if err != nil {
slog.WarnContext(ctx, "failed to parse a log as YAML. Skipping.")
continue
}
logReader, err := readerFactory.NewReader(adapter.Yaml(yamlString))
if err != nil {
slog.WarnContext(ctx, fmt.Sprintf("failed to create reader for log entry\n%s", err))
continue
}
commonLogFieldCache := log.NewCachedLogFieldExtractor(gcp_log.GCPCommonFieldExtractor{})
commonLogFieldCache.SetLogBodyCacheDirect(yamlString)
logEntry := log.NewLogEntity(logReader, commonLogFieldCache)
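				// Approximate this segment's progress by how far the latest log timestamp has advanced through the segment.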
percentages[workerIndex] = float32(logEntry.Timestamp().Sub(begin)) / float32(end.Sub(begin))
logSink <- logEntry
}
percentages[workerIndex] = 1
threadCount.Add(-1)
})
if errors.Is(cancellableCtx.Err(), context.Canceled) {
break
}
		// Stagger the start of each query by 3 seconds to avoid hitting the Logging API all at once and getting rate limited.
<-time.After(time.Second * 3)
}
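	// Wait for every query to finish, then close logSink so the aggregator goroutine drains and exits.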
wg.Wait()
	close(logSink)
	// Wait for the aggregator goroutine to finish appending so logEntries is complete before it is read below.
	<-aggregatorDone
err := context.Cause(cancellableCtx)
if err != nil {
cancel(err)
return nil, err
}
cancel(nil)
return logEntries, nil
}
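
// divideTimeSegments splits the range from startTime to endTime into count
// equal sub-intervals and returns the count+1 boundary timestamps. The last
// boundary is always set to exactly endTime to avoid rounding drift.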
func divideTimeSegments(startTime time.Time, endTime time.Time, count int) []time.Time {
duration := endTime.Sub(startTime)
	subIntervalDuration := duration / time.Duration(count)
	subIntervals := make([]time.Time, count+1)
	currentStart := startTime
	for i := range subIntervals {
		subIntervals[i] = currentStart
		currentStart = currentStart.Add(subIntervalDuration)
	}
	subIntervals[len(subIntervals)-1] = endTime
	return subIntervals
}