cli/azd/pkg/extensions/manager.go (542 lines of code) (raw):

// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. package extensions import ( "context" "crypto/sha256" "crypto/sha512" "encoding/hex" "errors" "fmt" "hash" "io" "log" "net/http" "os" "path/filepath" "runtime" "slices" "sort" "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Masterminds/semver/v3" "github.com/azure/azure-dev/cli/azd/pkg/alpha" "github.com/azure/azure-dev/cli/azd/pkg/config" "github.com/azure/azure-dev/cli/azd/pkg/osutil" "github.com/azure/azure-dev/cli/azd/pkg/rzip" ) const ( extensionRegistryUrl = "https://aka.ms/azd/extensions/registry" ) var ( ErrExtensionNotFound = errors.New("extension not found") ErrInstalledExtensionNotFound = errors.New("extension not found") ErrRegistryExtensionNotFound = errors.New("extension not found in registry") ErrExtensionInstalled = errors.New("extension already installed") FeatureExtensions = alpha.MustFeatureKey("extensions") ) // ListOptions is used to filter extensions by source and tags type ListOptions struct { // Source is used to specify the source of the extension to install Source string // Tags is used to specify the tags of the extension to install Tags []string } // FilterOptions is used to filter extensions by version and source type FilterOptions struct { // Version is used to specify the version of the extension to install Version string // Source is used to specify the source of the extension to install Source string } // LookupOptions is used to lookup extensions by id or namespace type LookupOptions struct { // Id is used to specify the id of the extension to install Id string // Namespace is used to specify the namespace of the extension to install Namespace string } type sourceFilterPredicate func(config *SourceConfig) bool type extensionFilterPredicate func(extension *ExtensionMetadata) bool // Manager is responsible for managing extensions type Manager struct { sourceManager *SourceManager sources []Source installed map[string]*Extension configManager config.UserConfigManager userConfig config.Config pipeline azruntime.Pipeline } // NewManager creates a new extension manager func NewManager( configManager config.UserConfigManager, sourceManager *SourceManager, transport policy.Transporter, ) (*Manager, error) { userConfig, err := configManager.Load() if err != nil { return nil, err } pipeline := azruntime.NewPipeline("azd-extensions", "1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{ Transport: transport, }) return &Manager{ userConfig: userConfig, configManager: configManager, sourceManager: sourceManager, pipeline: pipeline, }, nil } // ListInstalled retrieves a list of installed extensions func (m *Manager) ListInstalled() (map[string]*Extension, error) { var extensions map[string]*Extension if m.installed != nil { return m.installed, nil } ok, err := m.userConfig.GetSection(installedConfigKey, &extensions) if err != nil { return nil, fmt.Errorf("failed to get extensions section: %w", err) } if !ok || extensions == nil { extensions = map[string]*Extension{} } // Initialize the extensions since this are instantiated from JSON unmarshalling. for _, extension := range extensions { extension.init() } m.installed = extensions return m.installed, nil } // GetInstalled retrieves an installed extension by name func (m *Manager) GetInstalled(options LookupOptions) (*Extension, error) { extensions, err := m.ListInstalled() if err != nil { return nil, err } if options.Id != "" { extension, has := extensions[options.Id] if !has { return nil, fmt.Errorf("%s %w", options.Id, ErrInstalledExtensionNotFound) } return extension, nil } if options.Namespace != "" { for _, extension := range extensions { if strings.EqualFold(extension.Namespace, options.Namespace) { return extension, nil } } } return nil, ErrInstalledExtensionNotFound } // GetFromRegistry retrieves an extension from the registry by name func (m *Manager) GetFromRegistry( ctx context.Context, extensionId string, options *FilterOptions, ) (*ExtensionMetadata, error) { if options == nil { options = &FilterOptions{} } filterPredicate := func(config *SourceConfig) bool { if options.Source == "" { return true } return strings.EqualFold(config.Name, options.Source) } sources, err := m.getSources(ctx, filterPredicate) if err != nil { return nil, fmt.Errorf("failed getting extension sources: %w", err) } var match *ExtensionMetadata var sourceErr error for _, source := range sources { extension, err := source.GetExtension(ctx, extensionId) if err != nil { sourceErr = err } else if extension != nil { match = extension break } } if match != nil { return match, nil } if sourceErr != nil { return nil, fmt.Errorf("failed getting extension: %w", sourceErr) } return nil, fmt.Errorf("%s %w", extensionId, ErrRegistryExtensionNotFound) } func (m *Manager) ListFromRegistry(ctx context.Context, options *ListOptions) ([]*ExtensionMetadata, error) { allExtensions := []*ExtensionMetadata{} if options == nil { options = &ListOptions{} } var sourceFilterPredicate sourceFilterPredicate if options.Source != "" { sourceFilterPredicate = func(config *SourceConfig) bool { return strings.EqualFold(config.Name, options.Source) } } var extensionFilterPredicate extensionFilterPredicate if len(options.Tags) > 0 { // Find extensions that match all the incoming tags extensionFilterPredicate = func(extension *ExtensionMetadata) bool { match := false for _, optionTag := range options.Tags { match = slices.ContainsFunc(extension.Tags, func(extensionTag string) bool { return strings.EqualFold(optionTag, extensionTag) }) if !match { break } } return match } } sources, err := m.getSources(ctx, sourceFilterPredicate) if err != nil { return nil, fmt.Errorf("failed listing extensions: %w", err) } for _, source := range sources { filteredExtensions := []*ExtensionMetadata{} sourceExtensions, err := source.ListExtensions(ctx) if err != nil { return nil, fmt.Errorf("unable to list extension: %w", err) } for _, extension := range sourceExtensions { if extensionFilterPredicate == nil || extensionFilterPredicate(extension) { filteredExtensions = append(filteredExtensions, extension) } } // Sort by source, then repository path and finally name slices.SortFunc(filteredExtensions, func(a *ExtensionMetadata, b *ExtensionMetadata) int { if a.Source != b.Source { return strings.Compare(a.Source, b.Source) } return strings.Compare(a.Id, b.Id) }) allExtensions = append(allExtensions, filteredExtensions...) } return allExtensions, nil } // Install an extension by name and optional version // If no version is provided, the latest version is installed // Latest version is determined by the last element in the Versions slice func (m *Manager) Install(ctx context.Context, id string, options *FilterOptions) (*ExtensionVersion, error) { if options == nil { options = &FilterOptions{} } installed, err := m.GetInstalled(LookupOptions{Id: id}) if err == nil && installed != nil { return nil, fmt.Errorf("%s %w", id, ErrExtensionInstalled) } // Step 1: Find the extension by name extension, err := m.GetFromRegistry(ctx, id, options) if err != nil { return nil, err } // Step 2: Determine the version to install var selectedVersion *ExtensionVersion availableVersions := []*semver.Version{} availableVersionMap := map[*semver.Version]*ExtensionVersion{} // Create a map of available versions and sort them // This sorts the version from lowest to highest for _, extensionVersion := range extension.Versions { version, err := semver.NewVersion(extensionVersion.Version) if err != nil { return nil, fmt.Errorf("failed to parse version: %w", err) } availableVersionMap[version] = &extensionVersion availableVersions = append(availableVersions, version) } sort.Sort(semver.Collection(availableVersions)) if options.Version == "" || options.Version == "latest" { latestVersion := availableVersions[len(availableVersions)-1] selectedVersion = availableVersionMap[latestVersion] } else { // Find the best match for the version constraint constraint, err := semver.NewConstraint(options.Version) if err != nil { return nil, fmt.Errorf("failed to parse version constraint: %w", err) } var bestMatch *semver.Version for _, v := range availableVersions { // Find the highest version that satisfies the constraint if constraint.Check(v) { bestMatch = v } } if bestMatch == nil { return nil, fmt.Errorf( "no matching version found for extension: %s and constraint: %s", id, options.Version, ) } selectedVersion = availableVersionMap[bestMatch] } if selectedVersion == nil { return nil, fmt.Errorf("no compatible version found for extension: %s", id) } // Binaries are optional as long as dependencies are provided // This allows for extensions that are just extension packs if len(selectedVersion.Artifacts) == 0 && len(selectedVersion.Dependencies) == 0 { return nil, fmt.Errorf("no binaries or dependencies available for this version") } // Install dependencies if len(selectedVersion.Dependencies) > 0 { for _, dependency := range selectedVersion.Dependencies { dependencyInstallOptions := &FilterOptions{ Version: dependency.Version, Source: options.Source, } if _, err := m.Install(ctx, dependency.Id, dependencyInstallOptions); err != nil { if !errors.Is(err, ErrExtensionInstalled) { return nil, fmt.Errorf("failed to install dependency: %w", err) } } } } hasArtifact := len(selectedVersion.Artifacts) > 0 var relativeExtensionPath string var targetPath string // Install the artifacts if hasArtifact { // Step 3: Find the artifact for the current OS artifact, err := findArtifactForCurrentOS(selectedVersion) if err != nil { return nil, fmt.Errorf("failed to find artifact for current OS: %w", err) } // Step 4: Download the artifact to a temp location tempFilePath, err := m.downloadArtifact(ctx, artifact.URL) if err != nil { return nil, fmt.Errorf("failed to download artifact: %w", err) } // Clean up the temp file after all scenarios defer os.Remove(tempFilePath) // Step 5: Validate the checksum if provided if err := validateChecksum(tempFilePath, artifact.Checksum); err != nil { return nil, fmt.Errorf("checksum validation failed: %w", err) } userConfigDir, err := config.GetUserConfigDir() if err != nil { return nil, fmt.Errorf("failed to get user config directory: %w", err) } targetDir := filepath.Join(userConfigDir, "extensions", extension.Id) if err := os.MkdirAll(targetDir, os.ModePerm); err != nil { return nil, fmt.Errorf("failed to create target directory: %w", err) } // Step 6: Copy the artifact to the target directory // Check if artifact is a zip file, if so extract it to the target directory if strings.HasSuffix(tempFilePath, ".zip") { if err := rzip.ExtractToDirectory(tempFilePath, targetDir); err != nil { return nil, fmt.Errorf("failed to extract zip file: %w", err) } } else { targetPath = filepath.Join(targetDir, filepath.Base(tempFilePath)) if err := copyFile(tempFilePath, targetPath); err != nil { return nil, fmt.Errorf("failed to copy artifact to target location: %w", err) } } entryPoint := selectedVersion.EntryPoint if platformEntryPoint, has := artifact.AdditionalMetadata["entryPoint"]; has { entryPoint = fmt.Sprint(platformEntryPoint) } if entryPoint == "" { switch runtime.GOOS { case "windows": entryPoint = fmt.Sprintf("%s.exe", extension.Id) default: entryPoint = extension.Id } } targetPath := filepath.Join(targetDir, entryPoint) // Need to set the executable permission for the binary // This change is specifically required for Linux but will apply consistently across all platforms if err := os.Chmod(targetPath, osutil.PermissionExecutableFile); err != nil { return nil, fmt.Errorf("failed to set executable permission: %w", err) } relativeExtensionPath, err = filepath.Rel(userConfigDir, targetPath) if err != nil { return nil, fmt.Errorf("failed to get relative path: %w", err) } } // Step 7: Update the user config with the installed extension extensions, err := m.ListInstalled() if err != nil { return nil, fmt.Errorf("failed to list installed extensions: %w", err) } extensions[id] = &Extension{ Id: id, Capabilities: selectedVersion.Capabilities, Namespace: extension.Namespace, DisplayName: extension.DisplayName, Description: extension.Description, Version: selectedVersion.Version, Usage: selectedVersion.Usage, Path: relativeExtensionPath, Source: extension.Source, } if err := m.userConfig.Set(installedConfigKey, extensions); err != nil { return nil, fmt.Errorf("failed to set extensions section: %w", err) } if err := m.configManager.Save(m.userConfig); err != nil { return nil, fmt.Errorf("failed to save user config: %w", err) } log.Printf("Extension '%s' (version %s) installed successfully to %s\n", id, selectedVersion.Version, targetPath) return selectedVersion, nil } // Uninstall an extension by name func (m *Manager) Uninstall(id string) error { // Get the installed extension extension, err := m.GetInstalled(LookupOptions{Id: id}) if err != nil { return fmt.Errorf("failed to get installed extension: %w", err) } userConfigDir, err := config.GetUserConfigDir() if err != nil { return fmt.Errorf("failed to get user config directory: %w", err) } extensionDir := filepath.Join(userConfigDir, "extensions", extension.Id) if err := os.MkdirAll(extensionDir, os.ModePerm); err != nil { return fmt.Errorf("failed to create target directory: %w", err) } // Remove the extension artifacts when it exists _, err = os.Stat(extensionDir) if err == nil { if err := os.RemoveAll(extensionDir); err != nil { return fmt.Errorf("failed to remove extension: %w", err) } } // Update the user config extensions, err := m.ListInstalled() if err != nil { return fmt.Errorf("failed to list installed extensions: %w", err) } delete(extensions, id) if err := m.userConfig.Set(installedConfigKey, extensions); err != nil { return fmt.Errorf("failed to set extensions section: %w", err) } if err := m.configManager.Save(m.userConfig); err != nil { return fmt.Errorf("failed to save user config: %w", err) } log.Printf("Extension '%s' uninstalled successfully\n", id) return nil } // Upgrade upgrades the extension to the specified version // This is a convenience method that uninstalls the existing extension and installs the new version // If the version is not specified, the latest version is installed func (m *Manager) Upgrade(ctx context.Context, extensionId string, options *FilterOptions) (*ExtensionVersion, error) { if options == nil { options = &FilterOptions{} } if err := m.Uninstall(extensionId); err != nil { return nil, fmt.Errorf("failed to uninstall extension: %w", err) } extensionVersion, err := m.Install(ctx, extensionId, options) if err != nil { return nil, fmt.Errorf("failed to install extension: %w", err) } return extensionVersion, nil } // Helper function to find the artifact for the current OS func findArtifactForCurrentOS(version *ExtensionVersion) (*ExtensionArtifact, error) { if version.Artifacts == nil { return nil, fmt.Errorf("no binaries available for this version") } artifactVersions := []string{ fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH), runtime.GOOS, } for _, artifactVersion := range artifactVersions { artifact, exists := version.Artifacts[artifactVersion] if exists { if artifact.URL == "" { return nil, fmt.Errorf("artifact URL is missing for platform: %s", artifactVersion) } return &artifact, nil } } return nil, fmt.Errorf("no artifact available for platform: %s", artifactVersions) } // downloadFile downloads a file from the given URL and saves it to a temporary directory using the filename from the URL. func (m *Manager) downloadArtifact(ctx context.Context, artifactUrl string) (string, error) { if strings.HasPrefix(artifactUrl, "http://") || strings.HasPrefix(artifactUrl, "https://") { return m.downloadFromRemote(ctx, artifactUrl) } return m.copyFromLocalPath(artifactUrl) } // Handles downloading artifacts from HTTP/HTTPS URLs func (m *Manager) downloadFromRemote(ctx context.Context, artifactUrl string) (string, error) { req, err := azruntime.NewRequest(ctx, http.MethodGet, artifactUrl) if err != nil { return "", err } resp, err := m.pipeline.Do(req) if err != nil { return "", fmt.Errorf("failed to download file: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("failed to download file, status code: %d", resp.StatusCode) } filename := filepath.Base(artifactUrl) tempFilePath := filepath.Join(os.TempDir(), filename) tempFile, err := os.Create(tempFilePath) if err != nil { return "", fmt.Errorf("failed to create temporary file: %w", err) } defer tempFile.Close() _, err = io.Copy(tempFile, resp.Body) if err != nil { return "", fmt.Errorf("failed to write to temporary file: %w", err) } return tempFilePath, nil } // Handles copying artifacts from local or network file paths func (m *Manager) copyFromLocalPath(artifactPath string) (string, error) { // If the path is relative, resolve it against the userConfigDir if !filepath.IsAbs(artifactPath) { userConfigDir, err := config.GetUserConfigDir() if err != nil { return "", fmt.Errorf("failed to get user config directory: %w", err) } artifactPath = filepath.Join(userConfigDir, artifactPath) } if _, err := os.Stat(artifactPath); os.IsNotExist(err) { return "", fmt.Errorf("file does not exist at path: %s", artifactPath) } filename := filepath.Base(artifactPath) tempFilePath := filepath.Join(os.TempDir(), filename) if err := copyFile(artifactPath, tempFilePath); err != nil { return "", fmt.Errorf("failed to copy file to temporary location: %w", err) } return tempFilePath, nil } func (tm *Manager) getSources(ctx context.Context, filter sourceFilterPredicate) ([]Source, error) { if tm.sources != nil { return tm.sources, nil } configs, err := tm.sourceManager.List(ctx) if err != nil { return nil, fmt.Errorf("failed parsing extension sources: %w", err) } sources, err := tm.createSourcesFromConfig(ctx, configs, filter) if err != nil { return nil, fmt.Errorf("failed initializing extension sources: %w", err) } tm.sources = sources return tm.sources, nil } func (tm *Manager) createSourcesFromConfig( ctx context.Context, configs []*SourceConfig, filter sourceFilterPredicate, ) ([]Source, error) { sources := []Source{} for _, config := range configs { if filter != nil && !filter(config) { continue } source, err := tm.sourceManager.CreateSource(ctx, config) if err != nil { log.Printf("failed to create source: %s", err.Error()) continue } sources = append(sources, source) } return sources, nil } // validateChecksum validates the file at the given path against the expected checksum using the specified algorithm. func validateChecksum(filePath string, checksum ExtensionChecksum) error { // Check if checksum or required fields are nil if checksum.Algorithm == "" && checksum.Value == "" { log.Println("Checksum algorithm and value is missing, skipping checksum validation") return nil } var hashAlgo hash.Hash // Select the hashing algorithm based on the input switch checksum.Algorithm { case "sha256": hashAlgo = sha256.New() case "sha512": hashAlgo = sha512.New() default: return fmt.Errorf("unsupported checksum algorithm: %s", checksum.Algorithm) } // Open the file for reading file, err := os.Open(filePath) if err != nil { return fmt.Errorf("failed to open file for checksum validation: %w", err) } defer file.Close() // Compute the checksum if _, err := io.Copy(hashAlgo, file); err != nil { return fmt.Errorf("failed to compute checksum: %w", err) } // Convert the computed checksum to a hexadecimal string computedChecksum := hex.EncodeToString(hashAlgo.Sum(nil)) // Compare the computed checksum with the expected checksum if computedChecksum != checksum.Value { return fmt.Errorf("checksum mismatch: expected %s, got %s", checksum.Value, computedChecksum) } return nil } // Helper function to copy a file to the target directory func copyFile(src, dst string) error { input, err := os.Open(src) if err != nil { return fmt.Errorf("failed to open source file: %w", err) } defer input.Close() output, err := os.Create(dst) if err != nil { return fmt.Errorf("failed to create destination file: %w", err) } defer output.Close() _, err = io.Copy(output, input) if err != nil { return fmt.Errorf("failed to copy file: %w", err) } return nil }