bucket-utils/awsutils.go (145 lines of code) (raw):

// Package bucketutils finds S3 buckets reported by Security Hub as failing
// control S3.8 and blocks public access on them, skipping buckets that belong
// to a CloudFormation stack or that the caller explicitly excludes.
package bucketutils

import (
	"context"
	"fmt"
	"strings"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/cloudformation"
	cfnTypes "github.com/aws/aws-sdk-go-v2/service/cloudformation/types"
	"github.com/aws/aws-sdk-go-v2/service/s3"
	s3Types "github.com/aws/aws-sdk-go-v2/service/s3/types"
	"github.com/aws/aws-sdk-go-v2/service/securityhub"
	"github.com/guardian/fsbp-tools/fsbp-fix/common"
)

// findFailingBuckets asks Security Hub (via common.ReturnFindings) for up to
// bucketCount findings of control S3.8 and returns the resource IDs of every
// resource in those findings, with the "arn:aws:s3:::" prefix stripped.
// Resources without an ID are skipped rather than dereferenced.
func findFailingBuckets(ctx context.Context, securityHubClient *securityhub.Client, bucketCount int32) ([]string, error) {
	const controlID = "S3.8"
	const bucketARNPrefix = "arn:aws:s3:::"

	findings, err := common.ReturnFindings(ctx, securityHubClient, controlID, bucketCount)
	if err != nil {
		return nil, err
	}

	var bucketsToBlock []string
	for _, finding := range findings.Findings {
		for _, resource := range finding.Resources {
			id := aws.ToString(resource.Id)
			if id == "" {
				// A nil or empty ID cannot name a bucket; skipping avoids a nil dereference.
				continue
			}
			bucketsToBlock = append(bucketsToBlock, strings.TrimPrefix(id, bucketARNPrefix))
		}
	}
	return bucketsToBlock, nil
}

// getAllStackSummaries returns every stack summary that ListStacks reports,
// following NextToken until the listing is exhausted. It prints the number of
// stacks found. Any API error aborts the listing and is returned.
func getAllStackSummaries(ctx context.Context, cfnClient *cloudformation.Client) ([]cfnTypes.StackSummary, error) {
	var allStackSummaries []cfnTypes.StackSummary

	input := &cloudformation.ListStacksInput{}
	for {
		page, err := cfnClient.ListStacks(ctx, input)
		if err != nil {
			return nil, fmt.Errorf("listing CloudFormation stacks: %w", err)
		}
		allStackSummaries = append(allStackSummaries, page.StackSummaries...)
		if page.NextToken == nil {
			break
		}
		input.NextToken = page.NextToken
	}

	fmt.Printf("Found %d stacks in account.\n", len(allStackSummaries))
	return allStackSummaries, nil
}

// FindBucketsInStack returns the physical resource IDs of all resources in
// summaries whose type is "AWS::S3::Bucket". Resources with no physical ID
// are skipped. When at least one bucket is found, the stack name and its
// buckets are printed.
func FindBucketsInStack(summaries []cfnTypes.StackResourceSummary, stackName string) []string {
	var buckets []string
	for _, resource := range summaries {
		if aws.ToString(resource.ResourceType) != "AWS::S3::Bucket" {
			continue
		}
		// PhysicalResourceId is a pointer and may be nil; never dereference it blindly.
		if id := aws.ToString(resource.PhysicalResourceId); id != "" {
			buckets = append(buckets, id)
		}
	}
	if len(buckets) > 0 {
		fmt.Printf("\nStack: %s - Buckets: %v", stackName, buckets)
	}
	return buckets
}

// getAllStackResources returns every resource summary that ListStackResources
// reports for stackName, following NextToken until the listing is exhausted.
// Any API error aborts the listing and is returned.
func getAllStackResources(ctx context.Context, cfnClient *cloudformation.Client, stackName string) ([]cfnTypes.StackResourceSummary, error) {
	var allStackResources []cfnTypes.StackResourceSummary

	input := &cloudformation.ListStackResourcesInput{StackName: &stackName}
	for {
		page, err := cfnClient.ListStackResources(ctx, input)
		if err != nil {
			return nil, fmt.Errorf("listing resources of stack %q: %w", stackName, err)
		}
		allStackResources = append(allStackResources, page.StackResourceSummaries...)
		if page.NextToken == nil {
			break
		}
		input.NextToken = page.NextToken
	}
	return allStackResources, nil
}

// collectBucketsInStacks returns the S3 buckets found in every stack whose
// status is not DELETE_COMPLETE. It stops at the first CloudFormation error
// and returns it, so that callers never act on an incomplete bucket list.
func collectBucketsInStacks(ctx context.Context, cfnClient *cloudformation.Client) ([]string, error) {
	allStackSummaries, err := getAllStackSummaries(ctx, cfnClient)
	if err != nil {
		return nil, err
	}

	var bucketsInAStack []string
	for _, stack := range allStackSummaries {
		if stack.StackStatus == cfnTypes.StackStatusDeleteComplete {
			continue
		}
		stackName := aws.ToString(stack.StackName)
		stackResourceSummaries, err := getAllStackResources(ctx, cfnClient, stackName)
		if err != nil {
			return nil, err
		}
		bucketsInAStack = append(bucketsInAStack, FindBucketsInStack(stackResourceSummaries, stackName)...)
	}
	fmt.Println("") //Tidy up the log output
	return bucketsInAStack, nil
}

// listBucketsInStacks returns the S3 buckets found in the account's stacks.
// It keeps its original error-free signature for existing callers: a
// CloudFormation failure is printed instead of being silently discarded, and
// whatever collectBucketsInStacks returned alongside the error is passed
// through. Prefer collectBucketsInStacks when the result drives a destructive
// decision.
func listBucketsInStacks(ctx context.Context, cfnClient *cloudformation.Client) []string {
	buckets, err := collectBucketsInStacks(ctx, cfnClient)
	if err != nil {
		fmt.Println("Error listing buckets in stacks: " + err.Error())
	}
	return buckets
}

// FindBucketsToBlock returns the failing buckets that should have public
// access blocked: the buckets reported by findFailingBuckets minus those found
// in a CloudFormation stack and those listed in exclusions. It prints the
// buckets to block and a summary of the counts.
//
// An error is returned if the findings cannot be retrieved or if the
// CloudFormation stacks cannot be listed completely; in the latter case the
// exclusion list would be incomplete and stack-managed buckets could otherwise
// be selected for blocking. s3Client is accepted but not used here.
func FindBucketsToBlock(ctx context.Context, securityHubClient *securityhub.Client, s3Client *s3.Client, cfnClient *cloudformation.Client, bucketCount int32, exclusions []string) ([]string, error) {
	failingBuckets, err := findFailingBuckets(ctx, securityHubClient, bucketCount)
	if err != nil {
		return nil, err
	}
	failingBucketCount := len(failingBuckets)

	stackBuckets, err := collectBucketsInStacks(ctx, cfnClient)
	if err != nil {
		return nil, fmt.Errorf("finding buckets managed by CloudFormation: %w", err)
	}
	excludedBuckets := append(stackBuckets, exclusions...)

	fmt.Println("\nBuckets to exclude:")
	bucketsToBlock := common.Complement(failingBuckets, excludedBuckets)
	bucketsToBlockCount := len(bucketsToBlock)
	bucketsToSkipCount := failingBucketCount - bucketsToBlockCount

	fmt.Println("\nBlocking the following buckets:")
	for idx, bucket := range bucketsToBlock {
		fmt.Println(idx+1, bucket)
	}
	fmt.Print("\n")
	fmt.Println(failingBucketCount, "failing buckets found.")
	fmt.Println(bucketsToBlockCount, "to block, and", bucketsToSkipCount, "to skip.")
	return bucketsToBlock, nil
}

// blockPublicAccess enables all four public access block settings
// (BlockPublicAcls, IgnorePublicAcls, BlockPublicPolicy and
// RestrictPublicBuckets) on the named bucket and prints a confirmation on
// success. On failure it returns a nil output and the error.
func blockPublicAccess(ctx context.Context, s3Client *s3.Client, name string) (*s3.PutPublicAccessBlockOutput, error) {
	resp, err := s3Client.PutPublicAccessBlock(ctx, &s3.PutPublicAccessBlockInput{
		Bucket: aws.String(name),
		PublicAccessBlockConfiguration: &s3Types.PublicAccessBlockConfiguration{
			BlockPublicAcls:       aws.Bool(true),
			IgnorePublicAcls:      aws.Bool(true),
			BlockPublicPolicy:     aws.Bool(true),
			RestrictPublicBuckets: aws.Bool(true),
		},
	})
	if err != nil {
		return nil, err
	}

	fmt.Println("Public access blocked for bucket: " + name)
	return resp, nil
}

// BlockBuckets blocks public access on every bucket in bucketsToBlock when
// execute is true and the user confirms via common.UserConfirmation.
// Otherwise it prints why nothing was done. A failure on one bucket is
// printed and does not stop the remaining buckets from being processed; the
// closing message reports success only when every bucket was blocked.
func BlockBuckets(ctx context.Context, s3Client *s3.Client, bucketsToBlock []string, execute bool) {
	if !execute {
		fmt.Println("\nSkipping execution.")
		fmt.Println("Re-run with flag -execute to block access.")
		return
	}
	if !common.UserConfirmation() {
		fmt.Println("Exiting without blocking public access.")
		return
	}

	failed := 0
	for _, name := range bucketsToBlock {
		if _, err := blockPublicAccess(ctx, s3Client, name); err != nil {
			fmt.Printf("Error blocking public access for bucket %s: %s\n", name, err.Error())
			failed++
		}
	}

	if failed > 0 {
		// Do not claim success when some buckets were left unblocked.
		fmt.Printf("Failed to block public access for %d of %d buckets.\n", failed, len(bucketsToBlock))
		return
	}
	fmt.Println("Public access blocked for all buckets. Please note it may take 24 hours for SecurityHub to update.")
}