internal/provider/provider.go (174 lines of code) (raw):
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-License-Identifier: MPL-2.0
package provider
import (
"context"
"fmt"
listvalidators "github.com/hashicorp/terraform-plugin-framework-validators/listvalidator"
"github.com/hashicorp/terraform-plugin-framework/schema/validator"
"io"
"net/http"
"os"
"regexp"
"strconv"
"sync"
"time"
"github.com/hashicorp/terraform-plugin-framework/datasource"
"github.com/hashicorp/terraform-plugin-framework/function"
"github.com/hashicorp/terraform-plugin-framework/provider"
"github.com/hashicorp/terraform-plugin-framework/provider/schema"
"github.com/hashicorp/terraform-plugin-framework/resource"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/hashicorp/terraform-plugin-framework/types/basetypes"
)
// Compile-time check that *ModuleTelemetryProvider implements provider.Provider.
var _ provider.Provider = (*ModuleTelemetryProvider)(nil)
// ModuleTelemetryProvider is the "modtm" provider implementation.
type ModuleTelemetryProvider struct {
	// version is reported through Metadata. It is the release version for
	// published builds, "dev" for local builds and "test" for acceptance tests.
	version string
}
// ModuleTelemetryProviderModel mirrors the attributes of the provider block,
// as declared in Schema, for decoding through the `tfsdk` struct tags.
type ModuleTelemetryProviderModel struct {
	// Endpoint is the optional telemetry endpoint override.
	Endpoint types.String `tfsdk:"endpoint"`
	// Enabled toggles telemetry; Configure treats null as true.
	Enabled types.Bool `tfsdk:"enabled"`
	// ModuleSourceRegex is the optional allow list of module-source patterns.
	ModuleSourceRegex types.List `tfsdk:"module_source_regex"`
}
// providerConfig is the configuration Configure hands to every resource and
// data source of this provider.
type providerConfig struct {
	// endpointFunc returns the telemetry endpoint. Configure builds it so the
	// endpoint is resolved on first call and cached afterwards.
	endpointFunc func() string
	// enabled reports whether telemetry should be sent at all.
	enabled bool
	// defaultEndpoint is true when neither the provider block nor the
	// MODTM_ENDPOINT environment variable supplied an endpoint.
	defaultEndpoint bool
	// moduleSourceRegex is the allow list of module-source patterns; Configure
	// falls back to a single ".*" when none is configured.
	moduleSourceRegex []*regexp.Regexp
}
// Metadata reports the provider type name ("modtm") and the version this
// provider instance was created with.
func (p *ModuleTelemetryProvider) Metadata(ctx context.Context, req provider.MetadataRequest, resp *provider.MetadataResponse) {
	resp.TypeName, resp.Version = "modtm", p.version
}
// Schema declares the provider block: three optional attributes, `endpoint`
// (string), `enabled` (bool) and `module_source_regex` (list of strings that
// must hold at least one element, each validated by MustBeValidRegex).
func (p *ModuleTelemetryProvider) Schema(ctx context.Context, req provider.SchemaRequest, resp *provider.SchemaResponse) {
	regexAllowList := schema.ListAttribute{
		ElementType:         types.StringType,
		Optional:            true,
		MarkdownDescription: "List of regex as allow list for module source. Only module source that match one of the regex will be collected.",
		Validators: []validator.List{
			listvalidators.SizeAtLeast(1),
			listvalidators.ValueStringsAre(&MustBeValidRegex{}),
		},
	}

	attributes := map[string]schema.Attribute{
		"endpoint": schema.StringAttribute{
			Optional:            true,
			MarkdownDescription: "Telemetry endpoint to send data to.",
		},
		"enabled": schema.BoolAttribute{
			Optional:            true,
			MarkdownDescription: "Sending telemetry or not, set this argument to `false` would turn telemetry off. Defaults to `true`.",
		},
		"module_source_regex": regexAllowList,
	}

	resp.Schema = schema.Schema{Attributes: attributes}
}
// Configure reads the provider block and publishes a providerConfig as both
// DataSourceData and ResourceData.
//
//   - enabled defaults to true when the `enabled` attribute is null.
//   - The endpoint is resolved lazily, at most once (sync.Once), the first time
//     endpointFunc is called. Precedence: the `endpoint` attribute, then the
//     MODTM_ENDPOINT environment variable, then the default blob document. If
//     the blob download fails, the endpoint is "".
//   - defaultEndpoint is true when neither the attribute nor MODTM_ENDPOINT
//     supplied an endpoint.
//   - Every `module_source_regex` element is compiled. An invalid pattern is
//     reported as a diagnostic error rather than panicking. When no pattern is
//     configured, ".*" is used so every module source is allowed.
func (p *ModuleTelemetryProvider) Configure(ctx context.Context, req provider.ConfigureRequest, resp *provider.ConfigureResponse) {
	var data ModuleTelemetryProviderModel
	resp.Diagnostics.Append(req.Config.Get(ctx, &data)...)
	if resp.Diagnostics.HasError() {
		return
	}

	enabled := true
	if !data.Enabled.IsNull() {
		enabled = data.Enabled.ValueBool()
	}

	// MODTM_ENDPOINT is read once, here, so the lazy resolver and the
	// defaultEndpoint flag computed below agree on its value.
	var once sync.Once
	endpoint := ""
	endpointEnv := os.Getenv("MODTM_ENDPOINT")
	c := providerConfig{
		// NOTE(review): this closure captures the Configure ctx and may run
		// after this RPC has finished; confirm traceLog tolerates that.
		endpointFunc: func() string {
			once.Do(func() {
				if !data.Endpoint.IsNull() {
					endpoint = readEndpointFromProviderBlock(data)
					traceLog(ctx, fmt.Sprintf("Load provider's endpoint from provider block: %s", endpoint))
				} else if endpointEnv != "" {
					endpoint = endpointEnv
					traceLog(ctx, fmt.Sprintf("Load provider's endpoint from environment variable: %s", endpoint))
				} else {
					e, err := readEndpointFromBlob()
					if err != nil {
						endpoint = ""
						traceLog(ctx, "Failed to load provider's endpoint from default blob storage")
						return
					}
					endpoint = e
					traceLog(ctx, fmt.Sprintf("Load provider's endpoint from default blob storage: %s", endpoint))
				}
			})
			return endpoint
		},
		enabled: enabled,
	}

	if !data.ModuleSourceRegex.IsNull() {
		for _, value := range data.ModuleSourceRegex.Elements() {
			str, ok := value.(basetypes.StringValue)
			if !ok {
				resp.Diagnostics.AddError(
					"Invalid module_source_regex",
					fmt.Sprintf("Expected a string element, got %T.", value),
				)
				return
			}
			// Compile (not MustCompile): a bad user-supplied pattern must
			// surface as a diagnostic, not crash the provider process.
			re, err := regexp.Compile(str.ValueString())
			if err != nil {
				resp.Diagnostics.AddError(
					"Invalid module_source_regex",
					fmt.Sprintf("Cannot compile regex %q: %s", str.ValueString(), err),
				)
				return
			}
			c.moduleSourceRegex = append(c.moduleSourceRegex, re)
		}
	}
	if len(c.moduleSourceRegex) == 0 {
		// No allow list configured: accept every module source.
		c.moduleSourceRegex = append(c.moduleSourceRegex, regexp.MustCompile(".*"))
	}

	c.defaultEndpoint = data.Endpoint.IsNull() && endpointEnv == ""
	resp.DataSourceData = c
	resp.ResourceData = c
}
// readEndpointFromProviderBlock returns the provider block's `endpoint`
// attribute as a plain string. It unquotes the String() representation; when
// that representation is not a quoted Go string literal, it is returned as is.
func readEndpointFromProviderBlock(data ModuleTelemetryProviderModel) string {
	raw := data.Endpoint.String()
	if unquoted, err := strconv.Unquote(raw); err == nil {
		return unquoted
	}
	return raw
}
// Resources lists the constructors of the managed resources served by this
// provider.
func (p *ModuleTelemetryProvider) Resources(ctx context.Context) []func() resource.Resource {
	constructors := []func() resource.Resource{NewTelemetryResource}
	return constructors
}
// DataSources lists the constructors of the data sources served by this
// provider.
func (p *ModuleTelemetryProvider) DataSources(ctx context.Context) []func() datasource.DataSource {
	constructors := []func() datasource.DataSource{NewModuleSourceDataSource}
	return constructors
}
// Functions lists the constructors of the provider-defined functions, in
// registration order: module source first, then module version.
func (p *ModuleTelemetryProvider) Functions(ctx context.Context) []func() function.Function {
	constructors := []func() function.Function{
		NewModuleSourceFunction,
		NewModuleVersionFunction,
	}
	return constructors
}
// New returns a factory that builds a fresh ModuleTelemetryProvider stamped
// with the given version on every call.
func New(version string) func() provider.Provider {
	factory := func() provider.Provider {
		p := &ModuleTelemetryProvider{}
		p.version = version
		return p
	}
	return factory
}
// endpointBlobUrl is the location of the document holding the default
// telemetry endpoint. It is a variable rather than a constant so it can be
// overridden (presumably by tests — confirm).
var endpointBlobUrl = "https://avmtftelemetrysvc.blob.core.windows.net/blob/endpoint"

// readEndpointFromBlob downloads the default telemetry endpoint from
// endpointBlobUrl and returns the response body verbatim.
//
// The whole exchange (connect, headers and body) is bounded by a 5 second
// deadline. Exceeding it yields a "timeout on reading default endpoint" error.
// A non-2xx response is reported as an error instead of returning the error
// page as the endpoint.
func readEndpointFromBlob() (string, error) {
	// A request-scoped deadline replaces the former goroutine + select. That
	// pattern leaked the goroutine, which blocked forever on an unbuffered
	// channel send whenever the timeout fired first, and it never aborted the
	// in-flight request.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpointBlobUrl, nil) // #nosec G107
	if err != nil {
		return "", err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		if ctx.Err() != nil {
			return "", fmt.Errorf("timeout on reading default endpoint")
		}
		return "", err
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", fmt.Errorf("reading default endpoint: unexpected status %s", resp.Status)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		// The deadline also applies while streaming the body.
		if ctx.Err() != nil {
			return "", fmt.Errorf("timeout on reading default endpoint")
		}
		return "", err
	}
	return string(body), nil
}