internal/graph/graph_scanner.go (249 lines of code) (raw):
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
package graph
import (
"context"
"embed"
"io/fs"
"math"
"strings"
"sync"
"time"
"github.com/Azure/azqr/internal/models"
"github.com/Azure/azqr/internal/to"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
)
// embededFiles holds the recommendation definitions (*.yaml) and their Azure
// Resource Graph queries (kql/*.kql) compiled into the binary.
// getAprlRecommendations reads the sub-directory of this filesystem named by a
// ScanType value.
//
// NOTE(review): go:embed patterns use path.Match syntax, in which "**" is not a
// recursive wildcard (it behaves like "*"), so each pattern only matches files
// at one fixed directory depth.
// NOTE(review): the aprl/azure-specialized-workloads files are embedded, but no
// ScanType declared in this file points at them — confirm they are read elsewhere.
//go:embed aprl/azure-resources/**/**/*.yaml
//go:embed aprl/azure-resources/**/**/kql/*.kql
//go:embed aprl/azure-specialized-workloads/**/*.yaml
//go:embed aprl/azure-specialized-workloads/**/kql/*.kql
//go:embed azure-orphan-resources/**/*.yaml
//go:embed azure-orphan-resources/**/kql/*.kql
var embededFiles embed.FS
// AprlScanner evaluates APRL and Azure Orphan Resources recommendations through
// Azure Resource Graph for the resource types reported by its service scanners.
type AprlScanner struct {
	scanType        []ScanType             // embedded directories recommendations are loaded from
	serviceScanners []models.IAzureScanner // provide the resource types to evaluate
	filters         *models.Filters        // recommendation / service exclusions
	subscriptions   map[string]string      // subscription ID -> subscription name
}

// ScanType names the directory, inside the embedded filesystem, from which a
// family of recommendations is loaded.
type ScanType string
// AprlScanType selects the Azure Proactive Resiliency Library recommendations
// for Azure resources.
const AprlScanType ScanType = "aprl/azure-resources"

// OrphanScanType selects the Azure Orphan Resources recommendations.
const OrphanScanType ScanType = "azure-orphan-resources"
// NewAprlScanner builds an AprlScanner that loads both the APRL resource
// recommendations and the Azure Orphan Resources recommendations.
//
// serviceScanners supply the resource types to evaluate, filters holds the
// exclusions to apply, and subscriptions maps subscription IDs to their names.
func NewAprlScanner(serviceScanners []models.IAzureScanner, filters *models.Filters, subscriptions map[string]string) AprlScanner {
	scanner := AprlScanner{
		serviceScanners: serviceScanners,
		filters:         filters,
		subscriptions:   subscriptions,
	}
	scanner.scanType = []ScanType{AprlScanType, OrphanScanType}
	return scanner
}
// GetAprlRecommendations returns every recommendation known to the scanner,
// indexed by lower-cased resource type and then by recommendation ID.
//
// Each recommendation's Source is set to "AOR" when it comes from the orphan
// resources set and to "APRL" otherwise. A later scan type overwrites an
// earlier one when both define the same resource type and recommendation ID.
func (a AprlScanner) GetAprlRecommendations() map[string]map[string]models.AprlRecommendation {
	all := map[string]map[string]models.AprlRecommendation{}

	for _, scanType := range a.scanType {
		var source string
		switch scanType {
		case OrphanScanType:
			source = "AOR"
		default:
			source = "APRL"
		}

		loaded := a.getAprlRecommendations(string(scanType))
		for resourceType, byID := range loaded {
			for _, rec := range byID {
				if all[resourceType] == nil {
					all[resourceType] = map[string]models.AprlRecommendation{}
				}
				rec.Source = source
				all[resourceType][rec.RecommendationID] = rec
			}
		}
	}

	return all
}
// getAprlRecommendations loads the recommendations stored under path in the
// embedded filesystem and returns them indexed by lower-cased resource type and
// then by recommendation ID.
//
// Every *.kql file is treated as the Azure Resource Graph query of the
// recommendation whose ID equals the file's base name, and is copied into that
// recommendation's GraphQuery field. Recommendations without a matching .kql
// file keep whatever GraphQuery YAML unmarshalling produced.
//
// If the directory cannot be opened, or any file cannot be read or parsed, the
// problem is logged and nil is returned.
func (a AprlScanner) getAprlRecommendations(path string) map[string]map[string]models.AprlRecommendation {
	fsys, err := fs.Sub(embededFiles, path)
	if err != nil {
		log.Error().Err(err).Msgf("Failed to open embedded recommendations directory %s", path)
		return nil
	}

	// Index queries by file name without extension. Files sharing a name in
	// different sub-directories overwrite each other; the last one visited wins.
	queries := map[string]string{}
	err = fs.WalkDir(fsys, ".", func(p string, d fs.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if d.IsDir() || !strings.HasSuffix(p, ".kql") {
			return nil
		}
		content, err := fs.ReadFile(fsys, p)
		if err != nil {
			return err
		}
		queries[strings.TrimSuffix(d.Name(), ".kql")] = string(content)
		return nil
	})
	if err != nil {
		log.Error().Err(err).Msgf("Failed to load embedded queries from %s", path)
		return nil
	}

	r := map[string]map[string]models.AprlRecommendation{}
	err = fs.WalkDir(fsys, ".", func(p string, d fs.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if d.IsDir() || !strings.HasSuffix(p, ".yaml") {
			return nil
		}
		content, err := fs.ReadFile(fsys, p)
		if err != nil {
			return err
		}
		var recommendations []models.AprlRecommendation
		if err := yaml.Unmarshal(content, &recommendations); err != nil {
			return err
		}
		for _, recommendation := range recommendations {
			t := strings.ToLower(recommendation.ResourceType)
			if _, ok := r[t]; !ok {
				r[t] = map[string]models.AprlRecommendation{}
			}
			if query, ok := queries[recommendation.RecommendationID]; ok {
				recommendation.GraphQuery = query
			}
			r[t][recommendation.RecommendationID] = recommendation
		}
		return nil
	})
	if err != nil {
		log.Error().Err(err).Msgf("Failed to load embedded recommendations from %s", path)
		return nil
	}

	return r
}
// Scan evaluates, through Azure Resource Graph, the APRL and Azure Orphan
// Resources recommendations that apply to the resource types reported by the
// configured service scanners.
//
// It returns the applicable recommendations indexed by lower-cased resource
// type and recommendation ID. It also returns one result per row produced by
// their queries, minus rows whose resource ID is excluded by
// filters.Azqr.IsServiceExcluded.
//
// Rules are dispatched in batches to a fixed pool of workers, with a pause
// between batches to stay under Resource Graph throttling limits. If ctx is
// cancelled while waiting between batches, no further batches are dispatched.
// In that case the results of the already-dispatched batches are returned.
func (a AprlScanner) Scan(ctx context.Context, cred azcore.TokenCredential) (map[string]map[string]models.AprlRecommendation, []models.AprlResult) {
	const (
		batchSize  = 12 // rules per job
		numWorkers = 12 // size of the worker pool
		// Staggering queries to avoid throttling. Max 15 queries each 5 seconds.
		// https://learn.microsoft.com/en-us/azure/governance/resource-graph/concepts/guidance-for-throttled-requests#staggering-queries
		batchDelay = 5 * time.Second
	)

	recommendations := map[string]map[string]models.AprlRecommendation{}
	results := []models.AprlResult{}
	rules := []models.AprlRecommendation{}
	graph := NewGraphQuery(cred)

	// Collect the rules that apply to every resource type in scope.
	aprl := a.GetAprlRecommendations()
	for _, s := range a.serviceScanners {
		for _, t := range s.ResourceTypes() {
			models.LogResourceTypeScan(t)
			key := strings.ToLower(t)
			for id, r := range a.getGraphRules(t, aprl) {
				rules = append(rules, r)
				if recommendations[key] == nil {
					recommendations[key] = map[string]models.AprlRecommendation{}
				}
				recommendations[key][id] = r
			}
		}
	}

	batches := int(math.Ceil(float64(len(rules)) / float64(batchSize)))
	// Both channels can hold every batch, so neither sends nor result delivery
	// ever block.
	jobs := make(chan []models.AprlRecommendation, batches)
	ch := make(chan []models.AprlResult, batches)
	var wg sync.WaitGroup

	// Start workers; each exits once jobs is closed.
	for w := 0; w < numWorkers; w++ {
		go a.worker(ctx, graph, a.subscriptions, jobs, ch, &wg)
	}

	dispatched := 0
dispatch:
	for i := 0; i < len(rules); i += batchSize {
		j := i + batchSize
		if j > len(rules) {
			j = len(rules)
		}
		wg.Add(1)
		jobs <- rules[i:j]
		dispatched++

		if j == len(rules) {
			break // final batch: nothing left to stagger
		}
		select {
		case <-ctx.Done():
			log.Warn().Err(ctx.Err()).Msgf("Scan cancelled after dispatching %d of %d batches", dispatched, batches)
			break dispatch
		case <-time.After(batchDelay):
		}
	}

	// Wait for every dispatched batch to be processed.
	close(jobs)
	wg.Wait()

	// Each worker delivers exactly one value per dispatched batch.
	for i := 0; i < dispatched; i++ {
		for _, r := range <-ch {
			if a.filters.Azqr.IsServiceExcluded(r.ResourceID) {
				continue
			}
			results = append(results, r)
		}
	}
	return recommendations, results
}
// worker scans rule batches received from jobs until the channel is closed.
//
// For every batch it sends exactly one value on results, then calls wg.Done.
// The value is the batch's findings, or nil if the scan failed. Callers can
// therefore rely on one result per dispatched batch.
func (a AprlScanner) worker(ctx context.Context, graph *GraphQuery, subscriptions map[string]string, jobs <-chan []models.AprlRecommendation, results chan<- []models.AprlResult, wg *sync.WaitGroup) {
	for batch := range jobs {
		res, err := a.graphScan(ctx, graph, batch, subscriptions)
		if err != nil {
			// Skip the failed batch instead of terminating the whole process
			// from a background goroutine.
			log.Error().Err(err).Msg("Failed to scan")
			res = nil
		}
		results <- res
		wg.Done()
	}
}
// graphScan runs the Azure Resource Graph query of each rule in rules against
// the given subscriptions and converts every returned row into an AprlResult.
//
// subscriptions maps subscription IDs to display names. The name is attached
// to each result, or left empty when the ID is not in the map. Rules without a
// query are skipped. Rows that are not JSON objects, or whose "id" is not a
// string, are skipped with a warning. A result set whose rows lack an "id"
// column is abandoned. The returned error is currently always nil.
func (a AprlScanner) graphScan(ctx context.Context, graphClient *GraphQuery, rules []models.AprlRecommendation, subscriptions map[string]string) ([]models.AprlResult, error) {
	results := []models.AprlResult{}

	subs := make([]*string, 0, len(subscriptions))
	for s := range subscriptions {
		// Take the address of a per-iteration copy: before Go 1.22, &s would
		// point every element at the same loop variable.
		sub := s
		subs = append(subs, &sub)
	}

	sentQueries := 0
	for _, rule := range rules {
		if rule.GraphQuery == "" {
			continue
		}
		log.Debug().Msg(rule.GraphQuery)

		// Guard against recommendations that declare no learn-more link.
		learn := ""
		if len(rule.LearnMoreLink) > 0 {
			learn = rule.LearnMoreLink[0].Url
		}

		result := graphClient.Query(ctx, rule.GraphQuery, subs)
		if result.Data != nil {
			for _, row := range result.Data {
				m, ok := row.(map[string]interface{})
				if !ok {
					log.Warn().Msgf("Skipping result: unexpected row format in the response for recommendation: %s", rule.RecommendationID)
					continue
				}
				// A missing "id" column means the query does not project it, so
				// no row of this result set can be used.
				rawID, found := m["id"]
				if !found {
					log.Warn().Msgf("Skipping result: 'id' field is missing in the response for recommendation: %s", rule.RecommendationID)
					break
				}
				id, ok := rawID.(string)
				if !ok {
					log.Warn().Msgf("Skipping result: 'id' field is not a string in the response for recommendation: %s", rule.RecommendationID)
					continue
				}

				subscription := models.GetSubscriptionFromResourceID(id)
				subscriptionName := subscriptions[subscription] // "" when unknown

				results = append(results, models.AprlResult{
					RecommendationID:    rule.RecommendationID,
					Category:            models.RecommendationCategory(rule.Category),
					Recommendation:      rule.Recommendation,
					ResourceType:        rule.ResourceType,
					LongDescription:     rule.LongDescription,
					PotentialBenefits:   rule.PotentialBenefits,
					Impact:              models.RecommendationImpact(rule.Impact),
					Name:                to.String(m["name"]),
					ResourceID:          to.String(m["id"]),
					SubscriptionID:      subscription,
					SubscriptionName:    subscriptionName,
					ResourceGroup:       models.GetResourceGroupFromResourceID(id),
					Tags:                to.String(m["tags"]),
					Param1:              to.String(m["param1"]),
					Param2:              to.String(m["param2"]),
					Param3:              to.String(m["param3"]),
					Param4:              to.String(m["param4"]),
					Param5:              to.String(m["param5"]),
					Learn:               learn,
					AutomationAvailable: rule.AutomationAvailable,
					Source:              rule.Source,
				})
			}
		}

		sentQueries++
		// NOTE(review): this pauses only once, after the second query of the
		// batch, and does not match the "10 queries each 5 seconds" rate in
		// the comment below. Confirm the intended throttling policy.
		if sentQueries == 2 {
			// Staggering queries to avoid throttling. Max 10 queries each 5 seconds.
			// https://learn.microsoft.com/en-us/azure/governance/resource-graph/concepts/guidance-for-throttled-requests#staggering-queries
			time.Sleep(1 * time.Second)
		}
	}
	return results, nil
}
// getGraphRules returns the recommendations registered in aprl for service,
// which is looked up by its lower-cased name, keyed by recommendation ID.
//
// Recommendations excluded by the filters are dropped. So are those whose
// GraphQuery contains one of the markers "cannot-be-validated-with-arg",
// "under-development" or "under development". These markers presumably flag
// rules that have no runnable Resource Graph query; confirm against the
// recommendation sources.
func (a AprlScanner) getGraphRules(service string, aprl map[string]map[string]models.AprlRecommendation) map[string]models.AprlRecommendation {
	rules := map[string]models.AprlRecommendation{}

	candidates, found := aprl[strings.ToLower(service)]
	if !found {
		return rules
	}

	for _, rec := range candidates {
		query := rec.GraphQuery
		skip := a.filters.Azqr.IsRecommendationExcluded(rec.RecommendationID) ||
			strings.Contains(query, "cannot-be-validated-with-arg") ||
			strings.Contains(query, "under-development") ||
			strings.Contains(query, "under development")
		if skip {
			continue
		}
		rules[rec.RecommendationID] = rec
	}

	return rules
}