sg/internal/engine/rego_query.go (231 lines of code) (raw):
package engine
import (
"context"
"fmt"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/sourcegraph/conc/iter"
"github.com/Azure/ShieldGuard/sg/internal/armtemplateparser"
"github.com/Azure/ShieldGuard/sg/internal/policy"
"github.com/Azure/ShieldGuard/sg/internal/result"
"github.com/Azure/ShieldGuard/sg/internal/source"
"github.com/Azure/ShieldGuard/sg/internal/utils"
)
type loadedConfiguration struct {
Name string
Configuration ast.Value
}
func loadSource(source source.Source, shouldParseArmTemplateDefaults bool) ([]loadedConfiguration, error) {
var rv []loadedConfiguration
configurations, err := source.ParsedConfigurations()
if err != nil {
return nil, err
}
for _, configuration := range configurations {
t := ast.NewTerm(configuration)
if shouldParseArmTemplateDefaults {
armtemplateparser.ParseArmTemplateDefaults(t)
}
rv = append(rv, loadedConfiguration{
Name: source.Name(),
Configuration: t.Value,
})
}
return rv, nil
}
// PackageMain is the name of the main package.
// To ease the usage, we will only use rules from main package.
const PackageMain = "main"
// RegoEngine is the OPA based query engine implementation.
type RegoEngine struct {
policyPackages []policy.Package
compiler *ast.Compiler
compilerKey string
limiter limiter
queryCache QueryCache
parseArmTemplateDefaults bool
}
var _ Queryer = (*RegoEngine)(nil)
func (engine *RegoEngine) Query(
ctx context.Context,
source source.Source,
opts ...*QueryOptions,
) (result.QueryResults, error) {
loadedConfigurations, err := loadSource(source, engine.parseArmTemplateDefaults)
if err != nil {
return result.QueryResults{}, fmt.Errorf("failed to load source: %w", err)
}
var aggregatedQueryResults result.QueryResults
for _, loadedConfiguration := range loadedConfigurations {
for _, policyPackage := range engine.policyPackages {
queryResult, err := engine.queryPackage(ctx, policyPackage, loadedConfiguration)
if err != nil {
return result.QueryResults{}, err
}
aggregatedQueryResults = aggregatedQueryResults.Merge(queryResult)
}
}
aggregatedQueryResults.Source = source
return aggregatedQueryResults, nil
}
func (engine *RegoEngine) queryPackage(
ctx context.Context,
policyPackage policy.Package,
loadedConfiguration loadedConfiguration,
) (result.QueryResults, error) {
// NOTE: because an rego query returns all failures for a given rule,
// even if the rule is repeated with different bodies. Therefore,
// we should only query the distinct rules. At the end, the total success
// rules should be the count of total rules minus the query results plus
// succeeded query results.
allRules := policyPackage.Rules()
distinctRules := make([]policy.Rule, 0, len(allRules))
rulesSet := make(map[string]struct{}, len(allRules))
for _, rule := range allRules {
primaryRuleKey := rule.Query()
if _, ok := rulesSet[primaryRuleKey]; ok {
// skip duplicate rules
continue
}
rulesSet[primaryRuleKey] = struct{}{}
distinctRules = append(distinctRules, rule)
}
mm := iter.Mapper[policy.Rule, result.QueryResults]{
MaxGoroutines: len(distinctRules),
}
queryResults, err := mm.MapErr(
distinctRules,
func(rulePtr *policy.Rule) (result.QueryResults, error) {
done := engine.limiter.acquire()
defer done()
rule := *rulePtr
rv := result.QueryResults{}
if rule.Namespace != PackageMain {
// we only care about rules in the main package
return rv, nil
}
if !rule.IsKind(policy.QueryKindWarn, policy.QueryKindDeny, policy.QueryKindViolation) {
// not a query rule
return rv, nil
}
if err := engine.queryRule(
ctx,
policyPackage, rule,
loadedConfiguration, &rv,
); err != nil {
return rv, fmt.Errorf("failed to query rule: %w", err)
}
return rv, nil
},
)
if err != nil {
return result.QueryResults{}, nil
}
queryResult := result.QueryResults{}
for _, qr := range queryResults {
queryResult = queryResult.Merge(qr)
}
resultsCount := queryResult.Successes + len(queryResult.Failures) + len(queryResult.Warnings) + len(queryResult.Exceptions)
if duplicatedRulesCount := len(allRules) - resultsCount; duplicatedRulesCount > 0 {
queryResult.Successes += duplicatedRulesCount
}
return queryResult, nil
}
func resolveRuleDocLinkFn(policyPackage policy.Package) func(policy.Rule) (string, error) {
// TODO(hbc): cache resolved doc link by rule
return func(rule policy.Rule) (string, error) {
return policy.ResolveRuleDocLink(policyPackage.Spec(), rule)
}
}
func (engine *RegoEngine) queryRule(
ctx context.Context,
policyPackage policy.Package,
policyRule policy.Rule,
loadedConfiguration loadedConfiguration,
queryResult *result.QueryResults,
) error {
resolveRuleDocLink := resolveRuleDocLinkFn(policyPackage)
// execute exception query
exceptionQuery := fmt.Sprintf("data.%s.exception[_][_] == %q", PackageMain, policyRule.Name)
exceptions, err := engine.executeOneQuery(ctx, loadedConfiguration.Configuration, exceptionQuery)
if err != nil {
return fmt.Errorf("failed to execute exception query (%q): %w", exceptionQuery, err)
}
exceptions = utils.Filter(exceptions, func(x result.Result) bool { return x.Passed() })
// execute query
// NOTE: even if the exception query returns true, we still execute the query
query := fmt.Sprintf("data.%s.%s", PackageMain, policyRule.Query())
results, err := engine.executeOneQuery(ctx, loadedConfiguration.Configuration, query)
if err != nil {
return fmt.Errorf("failed to execute query (%q): %w", query, err)
}
// excluded by at least one exception
if len(exceptions) > 0 {
for idx := range exceptions {
exceptions[idx].Rule = policyRule
docLink, err := resolveRuleDocLink(policyRule)
if err != nil {
return fmt.Errorf("resolve rule doc link failed: %w", err)
}
exceptions[idx].RuleDocLink = docLink
}
queryResult.Exceptions = append(queryResult.Exceptions, exceptions...)
return nil
}
for _, result := range results {
if result.Passed() {
queryResult.Successes += 1
continue
}
result.Rule = policyRule
ruleDocLink, err := resolveRuleDocLink(policyRule)
if err != nil {
return fmt.Errorf("resolve rule doc link failed: %w", err)
}
result.RuleDocLink = ruleDocLink
switch {
case policyRule.IsKind(policy.QueryKindWarn):
queryResult.Warnings = append(queryResult.Warnings, result)
case policyRule.IsKind(policy.QueryKindViolation, policy.QueryKindDeny):
queryResult.Failures = append(queryResult.Failures, result)
}
}
return nil
}
func (engine *RegoEngine) createRegoInstance(
parsedInput ast.Value,
query string,
) *rego.Rego {
opts := []func(*rego.Rego){
rego.ParsedInput(parsedInput),
rego.Query(query), // TODO: consider pre-compile query for perf
rego.Compiler(engine.compiler),
}
return rego.New(opts...)
}
func (engine *RegoEngine) executeOneQuery(
ctx context.Context,
parsedInput ast.Value,
query string,
) ([]result.Result, error) {
// NOTE: we expect the policy implementation is deterministic, which provides
// the same results for the same policy rules, input and query.
cacheKey := queryCacheKey{
compilerKey: engine.compilerKey,
parsedInput: parsedInput,
query: query,
}
if cachedResults, ok := engine.queryCache.get(cacheKey); ok {
return cachedResults, nil
}
results, err := engine.executeOneQuerySlow(ctx, parsedInput, query)
if err != nil {
return nil, err
}
engine.queryCache.set(cacheKey, results)
return results, nil
}
func (engine *RegoEngine) executeOneQuerySlow(
ctx context.Context,
parsedInput ast.Value,
query string,
) ([]result.Result, error) {
regoInstance := engine.createRegoInstance(parsedInput, query)
resultSet, err := regoInstance.Eval(ctx)
if err != nil {
return nil, err
}
var rv []result.Result
for _, evalResult := range resultSet {
for _, expression := range evalResult.Expressions {
loadedResults, err := result.FromRegoExpression(query, expression)
if err != nil {
return nil, fmt.Errorf("failed to load result: %w", err)
}
rv = append(rv, loadedResults...)
}
}
return rv, nil
}