internal/resources/providers/azurelib/inventory/resource_graph_provider.go (85 lines of code) (raw):
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package inventory
import (
	"bytes"
	"context"
	"fmt"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resourcegraph/armresourcegraph"

	"github.com/elastic/cloudbeat/internal/infra/clog"
	"github.com/elastic/cloudbeat/internal/resources/utils/pointers"
)
// ResourceGraphAzureClientWrapper exposes the Resource Graph query call as a
// function field, so the call can be replaced (for example by a mock in tests)
// without depending on the concrete SDK client.
type ResourceGraphAzureClientWrapper struct {
// AssetQuery executes the given query request with the supplied options and
// returns the service response. NewResourceGraphProvider sets it to forward
// to armresourcegraph.Client.Resources.
AssetQuery func(ctx context.Context, query armresourcegraph.QueryRequest, options *armresourcegraph.ClientResourcesOptions) (armresourcegraph.ClientResourcesResponse, error)
}
// ResourceGraphProviderAPI is the interface returned by
// NewResourceGraphProvider for listing Azure assets by type.
type ResourceGraphProviderAPI interface {
// ListAllAssetTypesByName queries assetsGroup and returns the assets whose
// type equals any of the given asset type names.
ListAllAssetTypesByName(ctx context.Context, assetsGroup string, assets []string) ([]AzureAsset, error)
}
// ResourceGraphProvider implements ResourceGraphProviderAPI by running
// queries through a ResourceGraphAzureClientWrapper.
type ResourceGraphProvider struct {
// client performs the actual query calls.
client *ResourceGraphAzureClientWrapper
// log receives informational messages about the listings being run.
log *clog.Logger
}
// NewResourceGraphProvider builds a ResourceGraphProviderAPI whose queries are
// forwarded to resourceGraphClient and whose messages are written to log.
func NewResourceGraphProvider(log *clog.Logger, resourceGraphClient *armresourcegraph.Client) ResourceGraphProviderAPI {
	// The SDK call is reached through a function field rather than directly,
	// which lets tests substitute their own implementation.
	runQuery := func(ctx context.Context, query armresourcegraph.QueryRequest, options *armresourcegraph.ClientResourcesOptions) (armresourcegraph.ClientResourcesResponse, error) {
		return resourceGraphClient.Resources(ctx, query, options)
	}

	return &ResourceGraphProvider{
		client: &ResourceGraphAzureClientWrapper{AssetQuery: runQuery},
		log:    log,
	}
}
// ListAllAssetTypesByName logs the requested asset types, builds a query for
// assetGroup restricted to those types, asks for the results in object-array
// format and returns the assets gathered by runPaginatedQuery.
func (p *ResourceGraphProvider) ListAllAssetTypesByName(ctx context.Context, assetGroup string, assets []string) ([]AzureAsset, error) {
	p.log.Infof("Listing Azure assets: %v", assets)

	queryText := generateQuery(assetGroup, assets)
	requestOptions := &armresourcegraph.QueryRequestOptions{
		ResultFormat: to.Ptr(armresourcegraph.ResultFormatObjectArray),
	}
	request := armresourcegraph.QueryRequest{
		Query:   to.Ptr(queryText),
		Options: requestOptions,
	}

	return p.runPaginatedQuery(ctx, request)
}
// generateQuery returns the query text for assetGroup filtered to the given
// asset types. The first type is introduced with a "where" clause and each
// further type is chained with "or"; when assets is empty the result is just
// assetGroup. Type names are inserted verbatim, without any escaping.
func generateQuery(assetGroup string, assets []string) string {
	var buf bytes.Buffer
	buf.WriteString(assetGroup)

	// The connective changes after the first type has been written.
	connective := " | where type == '"
	for _, assetType := range assets {
		buf.WriteString(connective)
		buf.WriteString(assetType)
		buf.WriteString("'")
		connective = " or type == '"
	}

	return buf.String()
}
// runPaginatedQuery executes query through the wrapped client and keeps
// requesting further pages, feeding each response's skip token into the next
// request, until the service signals that there is nothing left to fetch.
// Every returned row is converted with getAssetFromData.
//
// Paging stops when the response is flagged as not truncated (a missing flag
// is treated the same way) or when no skip token is returned.
//
// The caller's query options are never modified: the skip token is written to
// a private copy. A nil response payload is treated as an empty page.
//
// An error is returned when the query call fails (wrapped, so errors.Is and
// errors.As still see the original) or when the payload is not a list of
// objects; in both cases no assets are returned.
func (p *ResourceGraphProvider) runPaginatedQuery(ctx context.Context, query armresourcegraph.QueryRequest) ([]AzureAsset, error) {
	// Copy the options before touching SkipToken: query is passed by value but
	// Options is a pointer shared with the caller, and it may also be nil.
	var options armresourcegraph.QueryRequestOptions
	if query.Options != nil {
		options = *query.Options
	}
	query.Options = &options

	var resourceAssets []AzureAsset
	for {
		response, err := p.client.AssetQuery(ctx, query, nil)
		if err != nil {
			return nil, fmt.Errorf("querying azure resource graph: %w", err)
		}

		if response.Data != nil {
			rows, ok := response.Data.([]any)
			if !ok {
				return nil, fmt.Errorf("azure resource graph: unexpected response data type %T", response.Data)
			}
			for _, row := range rows {
				fields, ok := row.(map[string]any)
				if !ok {
					return nil, fmt.Errorf("azure resource graph: unexpected response row type %T", row)
				}
				resourceAssets = append(resourceAssets, getAssetFromData(fields))
			}
		}

		// Guard the dereference: a response without the truncation flag must
		// end the loop rather than panic.
		truncated := response.ResultTruncated
		if truncated == nil ||
			*truncated == armresourcegraph.ResultTruncatedFalse ||
			pointers.Deref(response.SkipToken) == "" {
			break
		}
		query.Options.SkipToken = response.SkipToken
	}
	return resourceAssets, nil
}
// readPager drains pager and returns every page it yields, in order. The
// first error from NextPage aborts the read and is returned with a nil slice.
func readPager[T any](ctx context.Context, pager *runtime.Pager[T]) ([]T, error) {
	var pages []T
	for {
		if !pager.More() {
			return pages, nil
		}

		page, err := pager.NextPage(ctx)
		if err != nil {
			return nil, err
		}
		pages = append(pages, page)
	}
}