internal/vulnerability/snapshot.go (133 lines of code) (raw):
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package vulnerability
import (
"context"
"iter"
"sync"
"time"
"github.com/elastic/cloudbeat/internal/infra/clog"
"github.com/elastic/cloudbeat/internal/resources/providers/awslib/ec2"
)
const (
	// backgroundDeleteWorkers is the number of goroutines DeleteOldSnapshots
	// runs to delete old snapshots concurrently.
	backgroundDeleteWorkers = 3

	// backgroundDeleteTimeout is how far back from "now" DeleteOldSnapshots
	// places the `before` cutoff it passes to IterOwnedSnapshots (two days).
	backgroundDeleteTimeout = 48 * time.Hour
)
// snapshotCreatorDeleter is the set of provider operations SnapshotManager
// relies on: creating snapshots for an instance, deleting a single snapshot,
// and enumerating snapshots for background cleanup.
type snapshotCreatorDeleter interface {
	// IterOwnedSnapshots lazily yields snapshots for DeleteOldSnapshots.
	// NOTE(review): presumably only snapshots owned by this account and
	// created before `before` are yielded — confirm against the provider.
	IterOwnedSnapshots(ctx context.Context, before time.Time) iter.Seq[ec2.EBSSnapshot]

	// CreateSnapshots creates snapshots for the given instance and returns them.
	CreateSnapshots(ctx context.Context, ins *ec2.Ec2Instance) ([]ec2.EBSSnapshot, error)

	// DeleteSnapshot deletes a single snapshot.
	DeleteSnapshot(ctx context.Context, snapshot ec2.EBSSnapshot) error
}
// SnapshotManager creates EBS snapshots through a provider and remembers the
// ones it created, so they can later be deleted one at a time
// (DeleteSnapshot) or all together (Cleanup). It can also delete old
// snapshots found by the provider (DeleteOldSnapshots), independently of
// what it is tracking.
type SnapshotManager struct {
	// lock guards snapshots; every read or write of the map happens with it held.
	lock sync.Mutex
	// snapshots holds snapshots created via CreateSnapshots that have not been
	// removed yet, keyed by SnapshotId.
	snapshots map[string]ec2.EBSSnapshot
	// provider performs the actual create/delete/list calls.
	provider snapshotCreatorDeleter
	// logger receives progress and error messages for deletions.
	logger *clog.Logger
}
// NewSnapshotManager returns a SnapshotManager that tracks no snapshots yet,
// issues its create/delete/list calls through provider, and logs through
// logger.
func NewSnapshotManager(logger *clog.Logger, provider snapshotCreatorDeleter) *SnapshotManager {
	// The zero-value mutex is ready to use; only the map needs allocating.
	m := new(SnapshotManager)
	m.snapshots = map[string]ec2.EBSSnapshot{}
	m.provider = provider
	m.logger = logger
	return m
}
// CreateSnapshots asks the provider to snapshot ins and records every snapshot
// it reports, so that Cleanup can delete them later.
//
// On success it returns the created snapshots. If the provider returns an
// error, CreateSnapshots returns (nil, err). Any snapshots the provider
// reported alongside that error are still recorded, so they are not orphaned.
func (s *SnapshotManager) CreateSnapshots(ctx context.Context, ins *ec2.Ec2Instance) ([]ec2.EBSSnapshot, error) {
	// The provider call happens outside the lock; only the bookkeeping below
	// needs it.
	snaps, err := s.provider.CreateSnapshots(ctx, ins)

	// Record before inspecting err. The interface does not promise that a
	// failed call created nothing, and an unrecorded snapshot would never be
	// removed by Cleanup.
	if len(snaps) > 0 {
		s.lock.Lock()
		defer s.lock.Unlock()
		for _, snap := range snaps {
			s.snapshots[snap.SnapshotId] = snap
		}
	}

	if err != nil {
		return nil, err
	}
	return snaps, nil
}
// DeleteSnapshot deletes snapshot through the provider and then stops tracking
// it.
//
// The provider call is allowed to run for up to shutdownGracePeriod after ctx
// is cancelled. Failures are logged, not returned. The snapshot is dropped
// from tracking either way, so Cleanup will not retry it.
func (s *SnapshotManager) DeleteSnapshot(ctx context.Context, snapshot ec2.EBSSnapshot) {
	id := snapshot.SnapshotId

	deleteRemote := func(graceCtx context.Context) {
		s.delete(graceCtx, snapshot, "DeleteSnapshot")
	}
	runWithGrace(ctx, shutdownGracePeriod, deleteRemote)

	// Removing a map entry cannot panic, so an explicit Unlock is safe here.
	s.lock.Lock()
	delete(s.snapshots, id)
	s.lock.Unlock()
}
// Cleanup concurrently deletes every snapshot currently tracked, then forgets
// all of them.
//
// The manager's lock is held for the whole operation, so concurrent calls
// that need it wait until Cleanup finishes. This covers recording in
// CreateSnapshots and removal in DeleteSnapshot. Deletions may run for up to
// shutdownGracePeriod after ctx is cancelled. Failed deletions are only
// logged, and those snapshots are forgotten as well.
func (s *SnapshotManager) Cleanup(ctx context.Context) {
	s.lock.Lock()
	defer s.lock.Unlock()

	deleteAll := func(graceCtx context.Context) {
		var pending sync.WaitGroup
		pending.Add(len(s.snapshots))
		for _, tracked := range s.snapshots {
			// One goroutine per snapshot; all of them finish before deleteAll returns.
			go func(snap ec2.EBSSnapshot) {
				defer pending.Done()
				s.delete(graceCtx, snap, "Cleanup")
			}(tracked)
		}
		pending.Wait()
	}
	runWithGrace(ctx, shutdownGracePeriod, deleteAll)

	clear(s.snapshots)
}
// DeleteOldSnapshots deletes every snapshot the provider's IterOwnedSnapshots
// yields, using a cutoff of now minus backgroundDeleteTimeout. It uses
// backgroundDeleteWorkers concurrent workers.
//
// It returns once the iterator is exhausted and all queued deletions have
// finished. If ctx is cancelled first, it stops early but still waits for
// in-flight deletions. Unlike DeleteSnapshot and Cleanup, deletions here use
// ctx directly, with no grace period. Failures are only logged, and this
// method does not touch the set of tracked snapshots.
func (s *SnapshotManager) DeleteOldSnapshots(ctx context.Context) {
	queue := newContextualChan[ec2.EBSSnapshot]()
	var workers sync.WaitGroup

	// Deferred calls run last-in-first-out. The queue is closed first, which
	// lets idle workers exit; only then do we wait for the workers.
	defer workers.Wait()
	defer queue.Close()

	worker := func() {
		defer workers.Done()
		for {
			snap, ok := queue.Read(ctx)
			if !ok {
				// Queue closed or ctx cancelled.
				return
			}
			s.delete(ctx, snap, "DeleteOldSnapshots")
		}
	}

	workers.Add(backgroundDeleteWorkers)
	for i := 0; i < backgroundDeleteWorkers; i++ {
		go worker()
	}

	cutoff := time.Now().Add(-backgroundDeleteTimeout)
	for snap := range s.provider.IterOwnedSnapshots(ctx, cutoff) {
		if sent := queue.Write(ctx, snap); !sent {
			// ctx was cancelled while waiting for a free worker.
			return
		}
	}
}
// delete asks the provider to delete snapshot, logging the attempt and any
// failure. message names the calling operation and is embedded in both log
// lines. Errors are logged only, never returned.
func (s *SnapshotManager) delete(ctx context.Context, snapshot ec2.EBSSnapshot, message string) {
	id := snapshot.SnapshotId
	s.logger.Infof("VulnerabilityScanner.manager.%s %s", message, id)

	if err := s.provider.DeleteSnapshot(ctx, snapshot); err != nil {
		s.logger.Errorf("VulnerabilityScanner.manager.%s %s error: %s", message, id, err)
	}
}
// runWithGrace calls f with a context that may outlive ctx by up to grace.
//
// The context passed to f carries ctx's values but not its cancellation. It is
// cancelled either grace after ctx is cancelled or as soon as f returns,
// whichever comes first. The grace period is measured from the moment ctx is
// cancelled, not from the start of f. If ctx is already cancelled on entry,
// the grace period starts immediately.
func runWithGrace(ctx context.Context, grace time.Duration, f func(ctx context.Context)) {
	// WithoutCancel detaches newCtx from ctx's cancellation (values are kept);
	// WithCancel gives newCtx its own cancellation mechanism.
	newCtx, cancel := context.WithCancel(context.WithoutCancel(ctx))
	// Always cancel newCtx once f returns. This also releases the grace-period
	// waiter below if it is still running.
	defer cancel()

	stop := context.AfterFunc(ctx, func() {
		// ctx was cancelled: allow f up to grace more time. Waiting on
		// newCtx.Done() as well (instead of a fire-and-forget time.AfterFunc)
		// releases the timer as soon as f finishes, rather than leaving it
		// scheduled for the rest of the grace period.
		timer := time.NewTimer(grace)
		defer timer.Stop()
		select {
		case <-timer.C:
			cancel()
		case <-newCtx.Done():
		}
	})
	// If f finishes before ctx is cancelled, unregister the callback so it never runs.
	defer stop()

	f(newCtx)
}
// contextualChan wraps an unbuffered channel whose sends and receives can be
// abandoned when a context is done.
type contextualChan[T any] struct {
	ch chan T
}

// newContextualChan returns a contextualChan backed by a fresh unbuffered channel.
func newContextualChan[T any]() contextualChan[T] {
	var c contextualChan[T]
	c.ch = make(chan T)
	return c
}

// Write blocks until t is received or ctx is done. It reports whether t was
// delivered. It must not be called after Close.
func (c contextualChan[T]) Write(ctx context.Context, t T) bool {
	select {
	case c.ch <- t:
		return true
	case <-ctx.Done():
		return false
	}
}

// Read blocks until a value arrives, the channel is closed, or ctx is done.
// The boolean is false (with the zero value) when the channel is closed or
// ctx is done.
func (c contextualChan[T]) Read(ctx context.Context) (T, bool) {
	var zero T
	select {
	case <-ctx.Done():
		return zero, false
	case v, open := <-c.ch:
		return v, open
	}
}

// Close closes the underlying channel, which makes pending and future Reads
// return false.
func (c contextualChan[T]) Close() {
	close(c.ch)
}