internal/vulnerability/snapshot.go (133 lines of code) (raw):

// Licensed to Elasticsearch B.V. under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Elasticsearch B.V. licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package vulnerability import ( "context" "iter" "sync" "time" "github.com/elastic/cloudbeat/internal/infra/clog" "github.com/elastic/cloudbeat/internal/resources/providers/awslib/ec2" ) const ( backgroundDeleteWorkers = 3 backgroundDeleteTimeout = 2 * 24 * time.Hour ) type snapshotCreatorDeleter interface { CreateSnapshots(ctx context.Context, ins *ec2.Ec2Instance) ([]ec2.EBSSnapshot, error) DeleteSnapshot(ctx context.Context, snapshot ec2.EBSSnapshot) error IterOwnedSnapshots(ctx context.Context, before time.Time) iter.Seq[ec2.EBSSnapshot] } type SnapshotManager struct { lock sync.Mutex snapshots map[string]ec2.EBSSnapshot provider snapshotCreatorDeleter logger *clog.Logger } func NewSnapshotManager(logger *clog.Logger, provider snapshotCreatorDeleter) *SnapshotManager { return &SnapshotManager{ lock: sync.Mutex{}, snapshots: make(map[string]ec2.EBSSnapshot), provider: provider, logger: logger, } } func (s *SnapshotManager) CreateSnapshots(ctx context.Context, ins *ec2.Ec2Instance) ([]ec2.EBSSnapshot, error) { snaps, err := s.provider.CreateSnapshots(ctx, ins) if err != nil { return nil, err } s.lock.Lock() defer s.lock.Unlock() for _, snap := range snaps { s.snapshots[snap.SnapshotId] = 
snap } return snaps, err } func (s *SnapshotManager) DeleteSnapshot(ctx context.Context, snapshot ec2.EBSSnapshot) { runWithGrace(ctx, shutdownGracePeriod, func(ctx context.Context) { s.delete(ctx, snapshot, "DeleteSnapshot") }) s.lock.Lock() defer s.lock.Unlock() delete(s.snapshots, snapshot.SnapshotId) } func (s *SnapshotManager) Cleanup(ctx context.Context) { s.lock.Lock() defer s.lock.Unlock() runWithGrace(ctx, shutdownGracePeriod, func(ctx context.Context) { var wg sync.WaitGroup defer wg.Wait() for _, snap := range s.snapshots { wg.Add(1) go func() { defer wg.Done() s.delete(ctx, snap, "Cleanup") }() } }) clear(s.snapshots) } func (s *SnapshotManager) DeleteOldSnapshots(ctx context.Context) { var wg sync.WaitGroup defer wg.Wait() ch := newContextualChan[ec2.EBSSnapshot]() defer ch.Close() wg.Add(backgroundDeleteWorkers) for range backgroundDeleteWorkers { go func() { defer wg.Done() for { snap, ok := ch.Read(ctx) if !ok { return } s.delete(ctx, snap, "DeleteOldSnapshots") } }() } for snapshot := range s.provider.IterOwnedSnapshots(ctx, time.Now().Add(-backgroundDeleteTimeout)) { if !ch.Write(ctx, snapshot) { return } } } func (s *SnapshotManager) delete(ctx context.Context, snapshot ec2.EBSSnapshot, message string) { s.logger.Infof("VulnerabilityScanner.manager.%s %s", message, snapshot.SnapshotId) err := s.provider.DeleteSnapshot(ctx, snapshot) if err != nil { s.logger.Errorf("VulnerabilityScanner.manager.%s %s error: %s", message, snapshot.SnapshotId, err) } } // runWithGrace runs the given function with the given context but allowing an extra given grace period after the // context is cancelled. 
// The context handed to f keeps the values of ctx but is detached from its cancellation: it is cancelled only once
// grace has elapsed after ctx is cancelled, or when f returns, whichever comes first.
func runWithGrace(ctx context.Context, grace time.Duration, f func(ctx context.Context)) {
	// WithoutCancel: disassociate the cancellation of ctx from the cancellation of graceCtx.
	// WithCancel: add a cancellation mechanism to the new context.
	graceCtx, cancel := context.WithCancel(context.WithoutCancel(ctx))
	defer cancel() // in all cases, release graceCtx once the callback is finished

	stop := context.AfterFunc(ctx, func() {
		// Runs in its own goroutine after the original context is cancelled: wait for the grace period and then
		// cancel graceCtx. If the callback returns first, graceCtx is cancelled by the deferred cancel above, which
		// ends this wait and stops the timer instead of leaving it scheduled for the rest of the grace period.
		timer := time.NewTimer(grace)
		defer timer.Stop()
		select {
		case <-timer.C:
			cancel()
		case <-graceCtx.Done():
		}
	})
	defer stop() // if the callback finishes before ctx is cancelled, unregister the AfterFunc

	f(graceCtx) // finally, call the actual callback!
}

// contextualChan wraps an unbuffered channel whose send and receive operations give up when a context is cancelled.
type contextualChan[T any] struct {
	ch chan T
}

// newContextualChan returns a contextualChan backed by a new unbuffered channel.
func newContextualChan[T any]() contextualChan[T] {
	return contextualChan[T]{ch: make(chan T)}
}

// Write blocks until t is handed to a reader or ctx is cancelled. It reports whether the value was sent.
// Writing after Close panics, as with any closed channel.
func (s contextualChan[T]) Write(ctx context.Context, t T) bool {
	select {
	case <-ctx.Done():
		return false
	case s.ch <- t:
		return true
	}
}

// Read blocks until a value is available, the channel is closed or ctx is cancelled. The boolean is true only when a
// value was received; otherwise the zero value of T is returned.
func (s contextualChan[T]) Read(ctx context.Context) (T, bool) {
	select {
	case t, ok := <-s.ch:
		return t, ok
	case <-ctx.Done():
		var zero T
		return zero, false
	}
}

// Close closes the underlying channel, which makes pending and future reads return false. It must be called at most
// once.
func (s contextualChan[T]) Close() {
	close(s.ch)
}