tool/clean/clean_security_group/clean_security_group.go (356 lines of code) (raw):
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT
package main
import (
"context"
"flag"
"fmt"
"log"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/amazon-cloudwatch-agent/tool/clean"
)
// ec2Client is the subset of the EC2 API this tool calls. main passes the
// *ec2.Client built by ec2.NewFromConfig wherever an ec2Client is expected.
type ec2Client interface {
	// Read-only discovery calls.
	DescribeSecurityGroups(ctx context.Context, params *ec2.DescribeSecurityGroupsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSecurityGroupsOutput, error)
	DescribeNetworkInterfaces(ctx context.Context, params *ec2.DescribeNetworkInterfacesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeNetworkInterfacesOutput, error)

	// Mutating calls; only issued when dry-run mode is off.
	RevokeSecurityGroupIngress(ctx context.Context, params *ec2.RevokeSecurityGroupIngressInput, optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupIngressOutput, error)
	RevokeSecurityGroupEgress(ctx context.Context, params *ec2.RevokeSecurityGroupEgressInput, optFns ...func(*ec2.Options)) (*ec2.RevokeSecurityGroupEgressOutput, error)
	DeleteSecurityGroup(ctx context.Context, params *ec2.DeleteSecurityGroupInput, optFns ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error)
}
// SecurityGroupProcessChanSize is the buffer capacity of both channels that
// connect the fetcher, the worker pool and the result collector in
// deleteUnusedSecurityGroups.
const SecurityGroupProcessChanSize = 500
// Config holds the application configuration.
type Config struct {
	// ageThreshold is the age cutoff accepted via the -age flag.
	ageThreshold time.Duration
	// numWorkers is the size of the worker pool that processes security groups.
	numWorkers int
	// exceptionList holds substrings; a group whose name contains any of them is skipped.
	exceptionList []string
	// dryRun logs intended revocations/deletions instead of performing them.
	dryRun bool
	// skipVpcSGs is set by the -skip-vpc flag.
	skipVpcSGs bool
	// skipWithRules skips groups that have any ingress or egress rule.
	skipWithRules bool
}

// cfg is the process-wide configuration. init fills in defaults, and main's
// command-line flags may override them.
var cfg Config
// init installs the default configuration before main runs; the flags parsed
// in main may then override individual fields. Fields not listed here
// (skipVpcSGs, skipWithRules) keep their zero value, false.
func init() {
	cfg = Config{
		dryRun:        true,
		numWorkers:    30,
		ageThreshold:  clean.KeepDurationOneDay,
		exceptionList: []string{"default"},
	}
}
// main parses the command-line flags, builds an EC2 client from the default
// AWS configuration and removes unused security groups in the configured
// region. The whole run is bounded by a 10-minute timeout.
//
// Every flag uses the corresponding cfg field as its default, so the defaults
// set in init are the single source of truth.
//
// Errors while deleting are logged rather than fatal (best-effort cleanup);
// failing to load the AWS config or an invalid -workers value is fatal.
//
// NOTE(review): cfg.ageThreshold is only reported in the log line below;
// nothing in this file filters security groups by age.
func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()

	flag.BoolVar(&cfg.dryRun, "dry-run", cfg.dryRun, "Enable dry-run mode (no actual deletion)")
	flag.DurationVar(&cfg.ageThreshold, "age", cfg.ageThreshold, "Age threshold for security groups (e.g. 24h)")
	flag.BoolVar(&cfg.skipVpcSGs, "skip-vpc", cfg.skipVpcSGs, "Skip security groups associated with VPCs")
	flag.BoolVar(&cfg.skipWithRules, "skip-with-rules", cfg.skipWithRules, "Skip security groups that have ingress or egress rules")
	flag.IntVar(&cfg.numWorkers, "workers", cfg.numWorkers, "Number of concurrent workers processing security groups")
	flag.Parse()

	// With no workers the fetcher would block once the channel buffer fills.
	if cfg.numWorkers < 1 {
		log.Fatalf("Invalid -workers value %d: must be at least 1", cfg.numWorkers)
	}

	awsCfg, err := loadAWSConfig(ctx)
	if err != nil {
		log.Fatalf("Error loading AWS config: %v", err)
	}

	client := ec2.NewFromConfig(awsCfg)
	log.Printf("🔍 Searching for unused Security Groups older than %v in %s region\n",
		cfg.ageThreshold, awsCfg.Region)

	deletedGroups, err := deleteUnusedSecurityGroups(ctx, client)
	if err != nil {
		log.Printf("Error deleting security groups: %v", err)
	}
	log.Printf("Total security groups deleted: %d", len(deletedGroups))
}
// loadAWSConfig loads the SDK's default configuration and sets its retry mode
// to aws.RetryModeAdaptive (presumably to absorb API throttling caused by the
// concurrent workers — confirm).
func loadAWSConfig(ctx context.Context) (aws.Config, error) {
	// Named awsCfg so it does not shadow the package-level cfg.
	awsCfg, loadErr := config.LoadDefaultConfig(ctx)
	if loadErr != nil {
		return aws.Config{}, fmt.Errorf("loading AWS config: %w", loadErr)
	}
	awsCfg.RetryMode = aws.RetryModeAdaptive
	return awsCfg, nil
}
// deleteUnusedSecurityGroups feeds every security group in the region to a pool
// of cfg.numWorkers workers and collects the group IDs the workers report.
//
// The pipeline is always shut down in order: close the input channel, wait for
// the workers, then close the result channel and wait for the collector. This
// holds even when listing fails part-way, so no goroutine is left blocked. In
// that case the IDs collected so far are returned together with the error.
func deleteUnusedSecurityGroups(ctx context.Context, client ec2Client) ([]string, error) {
	if cfg.numWorkers < 1 {
		return nil, fmt.Errorf("invalid worker count %d: need at least one worker", cfg.numWorkers)
	}

	var (
		wg                       sync.WaitGroup
		handlerWg                sync.WaitGroup
		deletedSecurityGroups    []string
		foundSecurityGroupChan   = make(chan types.SecurityGroup, SecurityGroupProcessChanSize)
		deletedSecurityGroupChan = make(chan string, SecurityGroupProcessChanSize)
	)

	log.Printf("👷 Creating %d workers\n", cfg.numWorkers)
	for i := 0; i < cfg.numWorkers; i++ {
		wg.Add(1)
		w := worker{
			id:                        i,
			wg:                        &wg,
			incomingSecurityGroupChan: foundSecurityGroupChan,
			deletedSecurityGroupChan:  deletedSecurityGroupChan,
		}
		go w.processSecurityGroup(ctx, client)
	}

	// The collector is the only writer of deletedSecurityGroups; it is read
	// below only after handlerWg.Wait, so no lock is needed.
	handlerWg.Add(1)
	go func() {
		defer handlerWg.Done()
		handleDeletedSecurityGroups(&deletedSecurityGroups, deletedSecurityGroupChan)
	}()

	fetchErr := fetchAndProcessSecurityGroups(ctx, client, foundSecurityGroupChan)

	// Shut down even on error: returning early would leave workers blocked on
	// foundSecurityGroupChan and the collector blocked on its channel.
	close(foundSecurityGroupChan)
	wg.Wait()
	close(deletedSecurityGroupChan)
	handlerWg.Wait()

	if fetchErr != nil {
		return deletedSecurityGroups, fmt.Errorf("processing security groups: %w", fetchErr)
	}
	return deletedSecurityGroups, nil
}
// handleDeletedSecurityGroups appends every ID received on
// deletedSecurityGroupChan to *deletedSecurityGroups until the channel is
// closed, logging the running total after each one.
func handleDeletedSecurityGroups(deletedSecurityGroups *[]string, deletedSecurityGroupChan chan string) {
	for {
		id, open := <-deletedSecurityGroupChan
		if !open {
			return
		}
		*deletedSecurityGroups = append(*deletedSecurityGroups, id)
		total := len(*deletedSecurityGroups)
		log.Printf("🔍 Processed %d security groups so far\n", total)
	}
}
// worker is one member of the pool started by deleteUnusedSecurityGroups. It
// receives security groups to evaluate and reports the IDs of groups it
// selected for deletion.
type worker struct {
	// incomingSecurityGroupChan supplies groups to evaluate; it is shared by all workers.
	incomingSecurityGroupChan <-chan types.SecurityGroup
	// deletedSecurityGroupChan receives the IDs of groups selected for deletion.
	deletedSecurityGroupChan chan<- string
	// wg gets Done() when this worker's loop exits.
	wg *sync.WaitGroup
	// id is only used to tag log messages.
	id int
}
// processSecurityGroup is the worker loop. It handles security groups from
// w.incomingSecurityGroupChan until that channel is closed or ctx is cancelled,
// and calls w.wg.Done when it returns. A failure on one group is logged and
// does not stop the loop.
func (w *worker) processSecurityGroup(ctx context.Context, client ec2Client) {
	defer w.wg.Done()

	for {
		var (
			securityGroup types.SecurityGroup
			open          bool
		)
		select {
		case <-ctx.Done():
			log.Printf("Worker %d: Stopping due to context cancellation", w.id)
			return
		case securityGroup, open = <-w.incomingSecurityGroupChan:
		}
		if !open {
			return
		}

		err := w.handleSecurityGroup(ctx, client, securityGroup)
		if err != nil {
			log.Printf("Worker %d: Error processing security group: %v", w.id, err)
		}
	}
}
// handleSecurityGroup decides whether one security group should be removed and,
// if so, removes it (in dry-run mode it only logs the intent).
//
// The group is skipped in any of these cases:
//   - it is the "default" group;
//   - its name matches cfg.exceptionList;
//   - it has a VPC ID while cfg.skipVpcSGs is set;
//   - it is still in use;
//   - it has rules while cfg.skipWithRules is set.
//
// Its ID is sent on w.deletedSecurityGroupChan only after a successful
// deletion, or in dry-run mode when it would have been deleted. A failed
// deletion is therefore never counted.
//
// NOTE(review): cfg.ageThreshold is not applied; types.SecurityGroup appears to
// expose no creation timestamp to compare against — confirm.
func (w *worker) handleSecurityGroup(ctx context.Context, client ec2Client, securityGroup types.SecurityGroup) error {
	// aws.ToString yields "" for a missing field instead of panicking.
	sgID := aws.ToString(securityGroup.GroupId)
	sgName := aws.ToString(securityGroup.GroupName)
	if sgID == "" {
		return fmt.Errorf("security group %q has no group ID", sgName)
	}

	if isDefaultSecurityGroup(securityGroup) {
		log.Printf("⏭️ Worker %d: Skipping default security group: %s (%s)", w.id, sgID, sgName)
		return nil
	}
	if isSecurityGroupException(securityGroup) {
		log.Printf("⏭️ Worker %d: Skipping security group in exception list: %s (%s)", w.id, sgID, sgName)
		return nil
	}
	// Honour -skip-vpc, and do it before spending API calls on the in-use check.
	if cfg.skipVpcSGs && aws.ToString(securityGroup.VpcId) != "" {
		log.Printf("⏭️ Worker %d: Skipping VPC security group: %s (%s)", w.id, sgID, sgName)
		return nil
	}

	isInUse, err := isSecurityGroupInUse(ctx, client, sgID)
	if err != nil {
		return fmt.Errorf("checking if security group is in use: %w", err)
	}
	if isInUse {
		log.Printf("⏭️ Worker %d: Security group is in use: %s (%s)", w.id, sgID, sgName)
		return nil
	}

	if cfg.skipWithRules && hasRules(securityGroup) {
		log.Printf("⏭️ Worker %d: Skipping security group with rules: %s (%s)", w.id, sgID, sgName)
		return nil
	}

	log.Printf("🚨 Worker %d: Found unused security group: %s (%s)", w.id, sgID, sgName)

	// cleanSecurityGroupRules only logs what it would revoke in dry-run mode.
	if hasRules(securityGroup) {
		if err := cleanSecurityGroupRules(ctx, client, securityGroup); err != nil {
			return fmt.Errorf("cleaning security group rules: %w", err)
		}
	}

	if cfg.dryRun {
		log.Printf("🛑 Dry-Run: Would delete security group: %s (%s)", sgID, sgName)
		w.deletedSecurityGroupChan <- sgID
		return nil
	}

	if err := deleteSecurityGroup(ctx, client, sgID); err != nil {
		return err
	}
	// Report only after the deletion succeeded so the final total is accurate.
	w.deletedSecurityGroupChan <- sgID
	return nil
}
// deleteSecurityGroup calls DeleteSecurityGroup for securityGroupID and logs on
// success. An API failure is returned wrapped with the group ID.
func deleteSecurityGroup(ctx context.Context, client ec2Client, securityGroupID string) error {
	input := &ec2.DeleteSecurityGroupInput{GroupId: aws.String(securityGroupID)}
	if _, deleteErr := client.DeleteSecurityGroup(ctx, input); deleteErr != nil {
		return fmt.Errorf("deleting security group %s: %w", securityGroupID, deleteErr)
	}
	log.Printf("✅ Deleted security group: %s", securityGroupID)
	return nil
}
// cleanSecurityGroupRules revokes every inbound and outbound rule of
// securityGroup. The group is described again first, so the revocation uses
// its current rules rather than the snapshot taken when it was listed. In
// dry-run mode nothing is revoked; the number of rules that would be revoked
// is logged instead.
//
// The ingress and egress revocations run concurrently. If both fail, both
// errors are returned; they are joined with %w, so errors.Is/As match either.
func cleanSecurityGroupRules(ctx context.Context, client ec2Client, securityGroup types.SecurityGroup) error {
	sgID := aws.ToString(securityGroup.GroupId)

	describeOutput, err := client.DescribeSecurityGroups(ctx, &ec2.DescribeSecurityGroupsInput{
		GroupIds: []string{sgID},
	})
	if err != nil {
		return fmt.Errorf("describing security group %s: %w", sgID, err)
	}
	if len(describeOutput.SecurityGroups) == 0 {
		return fmt.Errorf("security group %s not found", sgID)
	}
	sg := describeOutput.SecurityGroups[0]

	var (
		wg sync.WaitGroup
		// Each error is written by exactly one goroutine and read only after wg.Wait.
		ingressErr, egressErr error
	)

	if len(sg.IpPermissions) > 0 {
		if cfg.dryRun {
			log.Printf("🛑 Dry-Run: Would revoke %d ingress rules from security group: %s",
				len(sg.IpPermissions), sgID)
		} else {
			wg.Add(1)
			go func() {
				defer wg.Done()
				_, err := client.RevokeSecurityGroupIngress(ctx, &ec2.RevokeSecurityGroupIngressInput{
					GroupId:       aws.String(sgID),
					IpPermissions: sg.IpPermissions,
				})
				if err != nil {
					ingressErr = fmt.Errorf("revoking ingress rules: %w", err)
					return
				}
				log.Printf("✅ Revoked ingress rules from security group: %s", sgID)
			}()
		}
	}

	if len(sg.IpPermissionsEgress) > 0 {
		if cfg.dryRun {
			log.Printf("🛑 Dry-Run: Would revoke %d egress rules from security group: %s",
				len(sg.IpPermissionsEgress), sgID)
		} else {
			wg.Add(1)
			go func() {
				defer wg.Done()
				_, err := client.RevokeSecurityGroupEgress(ctx, &ec2.RevokeSecurityGroupEgressInput{
					GroupId:       aws.String(sgID),
					IpPermissions: sg.IpPermissionsEgress,
				})
				if err != nil {
					egressErr = fmt.Errorf("revoking egress rules: %w", err)
					return
				}
				log.Printf("✅ Revoked egress rules from security group: %s", sgID)
			}()
		}
	}

	wg.Wait()

	switch {
	case ingressErr != nil && egressErr != nil:
		// Previously the egress failure was silently dropped in this case.
		return fmt.Errorf("%w; %w", ingressErr, egressErr)
	case ingressErr != nil:
		return ingressErr
	default:
		return egressErr // nil when neither revocation failed
	}
}
// fetchAndProcessSecurityGroups pages through DescribeSecurityGroups and queues
// every returned group on securityGroupChan.
//
// It returns nil once the last page (the one with no NextToken) has been
// queued. It returns the context's error if ctx is cancelled while describing
// or queueing, and a wrapped error if a Describe call fails. The channel is not
// closed here; the caller owns it.
func fetchAndProcessSecurityGroups(ctx context.Context, client ec2Client,
	securityGroupChan chan<- types.SecurityGroup) error {
	// Page size per DescribeSecurityGroups call. This is not the API maximum,
	// which is 1000.
	maxResults := int32(100)
	var nextToken *string

	for describeCount := 1; ; describeCount++ {
		if err := ctx.Err(); err != nil {
			return err
		}

		output, err := client.DescribeSecurityGroups(ctx, &ec2.DescribeSecurityGroupsInput{
			MaxResults: aws.Int32(maxResults),
			NextToken:  nextToken,
		})
		if err != nil {
			return fmt.Errorf("describing security groups: %w", err)
		}
		log.Printf("🔍 Described %d times | Found %d security groups\n", describeCount, len(output.SecurityGroups))

		for _, securityGroup := range output.SecurityGroups {
			select {
			case securityGroupChan <- securityGroup:
			case <-ctx.Done():
				return ctx.Err()
			}
		}

		// Return rather than `break`: a break inside the old select/default only
		// left the select, so the last page was re-fetched until the timeout.
		if aws.ToString(output.NextToken) == "" {
			return nil
		}
		nextToken = output.NextToken
	}
}
// isSecurityGroupInUse reports whether securityGroupID must be kept. That is
// the case when it is attached to at least one network interface, or when
// another security group references it in an inbound or outbound rule.
//
// References are looked up with server-side filters, so each call returns only
// the groups that mention securityGroupID rather than every group in the
// region. Every query follows NextToken. The checks run in order and stop at
// the first positive result.
func isSecurityGroupInUse(ctx context.Context, client ec2Client, securityGroupID string) (bool, error) {
	attached, err := securityGroupHasNetworkInterfaces(ctx, client, securityGroupID)
	if err != nil {
		return false, err
	}
	if attached {
		return true, nil
	}

	// EC2 ANDs filters with different names, so inbound and outbound
	// references need separate queries.
	for _, filterName := range []string{"ip-permission.group-id", "egress.ip-permission.group-id"} {
		referenced, err := securityGroupReferencedByFilter(ctx, client, securityGroupID, filterName)
		if err != nil {
			return false, err
		}
		if referenced {
			return true, nil
		}
	}
	return false, nil
}

// securityGroupHasNetworkInterfaces reports whether any network interface is
// associated with securityGroupID (filter "group-id"), following pagination.
func securityGroupHasNetworkInterfaces(ctx context.Context, client ec2Client, securityGroupID string) (bool, error) {
	input := &ec2.DescribeNetworkInterfacesInput{
		Filters: []types.Filter{{
			Name:   aws.String("group-id"),
			Values: []string{securityGroupID},
		}},
	}
	for {
		output, err := client.DescribeNetworkInterfaces(ctx, input)
		if err != nil {
			return false, fmt.Errorf("describing network interfaces: %w", err)
		}
		if len(output.NetworkInterfaces) > 0 {
			return true, nil
		}
		// Filtered pages can be empty while more pages remain, so keep going.
		if aws.ToString(output.NextToken) == "" {
			return false, nil
		}
		input.NextToken = output.NextToken
	}
}

// securityGroupReferencedByFilter reports whether a security group other than
// securityGroupID matches filterName=securityGroupID and names securityGroupID
// in one of its ingress or egress rules. It follows pagination.
func securityGroupReferencedByFilter(ctx context.Context, client ec2Client, securityGroupID, filterName string) (bool, error) {
	input := &ec2.DescribeSecurityGroupsInput{
		Filters: []types.Filter{{
			Name:   aws.String(filterName),
			Values: []string{securityGroupID},
		}},
	}
	for {
		output, err := client.DescribeSecurityGroups(ctx, input)
		if err != nil {
			return false, fmt.Errorf("describing security groups: %w", err)
		}
		for _, sg := range output.SecurityGroups {
			// Self-references are ignored.
			if aws.ToString(sg.GroupId) == securityGroupID {
				continue
			}
			// Re-check the rules so the result does not rely on filter semantics alone.
			if isReferencedInRules(sg.IpPermissions, securityGroupID) ||
				isReferencedInRules(sg.IpPermissionsEgress, securityGroupID) {
				return true, nil
			}
		}
		if aws.ToString(output.NextToken) == "" {
			return false, nil
		}
		input.NextToken = output.NextToken
	}
}
// isReferencedInRules reports whether any rule in permissions has a
// user/group pair whose GroupId equals securityGroupID. Pairs without a
// GroupId never match.
func isReferencedInRules(permissions []types.IpPermission, securityGroupID string) bool {
	for i := range permissions {
		for _, pair := range permissions[i].UserIdGroupPairs {
			if id := pair.GroupId; id != nil && *id == securityGroupID {
				return true
			}
		}
	}
	return false
}
// isDefaultSecurityGroup reports whether the group is named exactly "default".
// A missing (nil) name is treated as "not default" instead of panicking.
func isDefaultSecurityGroup(securityGroup types.SecurityGroup) bool {
	return aws.ToString(securityGroup.GroupName) == "default"
}
// isSecurityGroupException reports whether the group's name contains any entry
// of cfg.exceptionList as a substring. A missing (nil) name is treated as the
// empty string instead of panicking.
func isSecurityGroupException(securityGroup types.SecurityGroup) bool {
	sgName := aws.ToString(securityGroup.GroupName)
	for _, exception := range cfg.exceptionList {
		if strings.Contains(sgName, exception) {
			return true
		}
	}
	return false
}
// hasRules reports whether the group carries at least one inbound or outbound
// rule, based on the rules included in the given snapshot.
func hasRules(securityGroup types.SecurityGroup) bool {
	if len(securityGroup.IpPermissions) != 0 {
		return true
	}
	return len(securityGroup.IpPermissionsEgress) != 0
}