cli/azd/internal/grpcserver/event_service.go (352 lines of code) (raw):

// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. package grpcserver import ( "context" "errors" "fmt" "io" "log" "sync" "github.com/azure/azure-dev/cli/azd/pkg/azdext" "github.com/azure/azure-dev/cli/azd/pkg/environment" "github.com/azure/azure-dev/cli/azd/pkg/ext" "github.com/azure/azure-dev/cli/azd/pkg/extensions" "github.com/azure/azure-dev/cli/azd/pkg/input" "github.com/azure/azure-dev/cli/azd/pkg/lazy" "github.com/azure/azure-dev/cli/azd/pkg/project" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) // noEnvResolver is a resolver that always returns an empty string. // This is used when an environment is not available to resolve environment variables referenced in project config. var noEnvResolver = func(name string) string { return "" } // eventService implements azdext.EventServiceServer. type eventService struct { azdext.UnimplementedEventServiceServer extensionManager *extensions.Manager console input.Console lazyProject *lazy.Lazy[*project.ProjectConfig] lazyEnv *lazy.Lazy[*environment.Environment] projectEvents sync.Map // key: string, value: chan *azdext.ProjectHandlerStatus serviceEvents sync.Map // key: string, value: chan *azdext.ServiceHandlerStatus } func NewEventService( extensionManager *extensions.Manager, lazyProject *lazy.Lazy[*project.ProjectConfig], lazyEnv *lazy.Lazy[*environment.Environment], console input.Console, ) azdext.EventServiceServer { return &eventService{ extensionManager: extensionManager, lazyProject: lazyProject, lazyEnv: lazyEnv, console: console, } } // EventStream handles bidirectional streaming. func (s *eventService) EventStream(stream grpc.BidiStreamingServer[azdext.EventMessage, azdext.EventMessage]) error { ctx := stream.Context() extensionClaims, err := GetExtensionClaims(ctx) if err != nil { return fmt.Errorf("failed to get extension claims: %w", err) } options := extensions.LookupOptions{ Id: extensionClaims.Subject, } extension, err := s.extensionManager.GetInstalled(options) if err != nil { return status.Errorf(codes.FailedPrecondition, "failed to get extension: %s", err.Error()) } if !extension.HasCapability(extensions.LifecycleEventsCapability) { return status.Errorf(codes.PermissionDenied, "extension does not support lifecycle events") } for { select { case <-ctx.Done(): log.Println("Context cancelled by caller, exiting EventStream") return nil default: msg, err := stream.Recv() if errors.Is(err, io.EOF) { log.Println("Stream closed by server") return nil } if err != nil { return err } switch msg.MessageType.(type) { case *azdext.EventMessage_SubscribeProjectEvent: subscribeMsg := msg.GetSubscribeProjectEvent() if err := s.handleSubscribeProjectEvent(extension, subscribeMsg, stream); err != nil { log.Println(err.Error()) } case *azdext.EventMessage_SubscribeServiceEvent: subscribeMsg := msg.GetSubscribeServiceEvent() if err := s.handleSubscribeServiceEvent(extension, subscribeMsg, stream); err != nil { log.Println(err.Error()) } case *azdext.EventMessage_ProjectHandlerStatus: statusMsg := msg.GetProjectHandlerStatus() s.handleProjectHandlerStatus(extension, statusMsg) case *azdext.EventMessage_ServiceHandlerStatus: statusMsg := msg.GetServiceHandlerStatus() s.handleServiceHandlerStatus(extension, statusMsg) case *azdext.EventMessage_ExtensionReadyEvent: s.handleReadyEvent(extension) } } } } func (s *eventService) handleReadyEvent(extension *extensions.Extension) { extension.Initialize() } // ----- Project Event Handlers ----- func (s *eventService) handleSubscribeProjectEvent( extension *extensions.Extension, subscribeMsg *azdext.SubscribeProjectEvent, stream grpc.BidiStreamingServer[azdext.EventMessage, azdext.EventMessage], ) error { projectConfig, err := s.lazyProject.GetValue() if err != nil { return err } for i := 0; i < len(subscribeMsg.EventNames); i++ { eventName := subscribeMsg.EventNames[i] fullEventName := fmt.Sprintf("%s.%s", extension.Id, eventName) // Create a channel for this event. s.projectEvents.Store(fullEventName, make(chan *azdext.ProjectHandlerStatus, 1)) evt := ext.Event(eventName) handler := s.createProjectEventHandler(stream, extension, eventName) if err := projectConfig.AddHandler(evt, handler); err != nil { return fmt.Errorf("failed to add handler for event %s: %w", eventName, err) } } return nil } func (s *eventService) createProjectEventHandler( stream grpc.BidiStreamingServer[azdext.EventMessage, azdext.EventMessage], extension *extensions.Extension, eventName string, ) ext.EventHandlerFn[project.ProjectLifecycleEventArgs] { return func(ctx context.Context, args project.ProjectLifecycleEventArgs) error { previewTitle := fmt.Sprintf("%s (%s)", extension.DisplayName, eventName) defer s.syncExtensionOutput(ctx, extension, previewTitle)() // Send the invoke message. if err := s.sendProjectInvokeMessage(stream, eventName, args.Project); err != nil { return err } // Wait for status response. return s.waitForProjectStatus(ctx, eventName, extension) } } func (s *eventService) sendProjectInvokeMessage( stream grpc.BidiStreamingServer[azdext.EventMessage, azdext.EventMessage], eventName string, proj *project.ProjectConfig, ) error { return stream.Send(&azdext.EventMessage{ MessageType: &azdext.EventMessage_InvokeProjectHandler{ InvokeProjectHandler: &azdext.InvokeProjectHandler{ EventName: eventName, Project: s.createProjectConfig(proj), }, }, }) } func (s *eventService) waitForProjectStatus(ctx context.Context, eventName string, extension *extensions.Extension) error { extensionEventName := fmt.Sprintf("%s.%s", extension.Id, eventName) val, ok := s.projectEvents.Load(extensionEventName) if !ok { return fmt.Errorf("no status channel for event: %s", eventName) } ch := val.(chan *azdext.ProjectHandlerStatus) var status *azdext.ProjectHandlerStatus select { case <-ctx.Done(): return ctx.Err() case status = <-ch: // Clean up after receiving status. s.projectEvents.Delete(extensionEventName) } if status.Status == "failed" { return fmt.Errorf("extension %s project hook %s failed: %s", extension.Id, eventName, status.Message) } return nil } // ----- Service Event Handlers ----- func (s *eventService) handleSubscribeServiceEvent( extension *extensions.Extension, subscribeMsg *azdext.SubscribeServiceEvent, stream grpc.BidiStreamingServer[azdext.EventMessage, azdext.EventMessage], ) error { projectConfig, err := s.lazyProject.GetValue() if err != nil { return err } for i := 0; i < len(subscribeMsg.EventNames); i++ { eventName := subscribeMsg.EventNames[i] evt := ext.Event(eventName) for serviceName, serviceConfig := range projectConfig.Services { if subscribeMsg.Language != "" && string(serviceConfig.Language) != subscribeMsg.Language { continue } if subscribeMsg.Host != "" && string(serviceConfig.Host) != subscribeMsg.Host { continue } // Create a channel for this event. // fullEventName is used to uniquely identify the event for a specific service. fullEventName := fmt.Sprintf("%s.%s.%s", extension.Id, serviceName, eventName) s.serviceEvents.Store(fullEventName, make(chan *azdext.ServiceHandlerStatus, 1)) handler := s.createServiceEventHandler(stream, serviceConfig, extension, eventName) if err := serviceConfig.AddHandler(evt, handler); err != nil { return fmt.Errorf("failed to add handler for event %s: %w", eventName, err) } } } return nil } func (s *eventService) createServiceEventHandler( stream grpc.BidiStreamingServer[azdext.EventMessage, azdext.EventMessage], serviceConfig *project.ServiceConfig, extension *extensions.Extension, eventName string, ) ext.EventHandlerFn[project.ServiceLifecycleEventArgs] { fullEventName := fmt.Sprintf("%s.%s.%s", extension.Id, serviceConfig.Name, eventName) return func(ctx context.Context, args project.ServiceLifecycleEventArgs) error { previewTitle := fmt.Sprintf("%s (%s.%s)", extension.DisplayName, args.Service.Name, eventName) defer s.syncExtensionOutput(ctx, extension, previewTitle)() // Send the invoke message. if err := s.sendServiceInvokeMessage(stream, eventName, args.Project, args.Service); err != nil { return err } // Wait for status response. return s.waitForServiceStatus(ctx, fullEventName, extension) } } func (s *eventService) sendServiceInvokeMessage( stream grpc.BidiStreamingServer[azdext.EventMessage, azdext.EventMessage], eventName string, proj *project.ProjectConfig, svc *project.ServiceConfig, ) error { return stream.Send(&azdext.EventMessage{ MessageType: &azdext.EventMessage_InvokeServiceHandler{ InvokeServiceHandler: &azdext.InvokeServiceHandler{ EventName: eventName, Project: s.createProjectConfig(proj), Service: s.createServiceConfig(svc), }, }, }) } func (s *eventService) waitForServiceStatus( ctx context.Context, fullEventName string, extension *extensions.Extension, ) error { val, ok := s.serviceEvents.Load(fullEventName) if !ok { return fmt.Errorf("no status channel for event: %s", fullEventName) } ch := val.(chan *azdext.ServiceHandlerStatus) var status *azdext.ServiceHandlerStatus select { case <-ctx.Done(): return ctx.Err() case status = <-ch: // Clean up after receiving status. s.serviceEvents.Delete(fullEventName) } if status.Status == "failed" { return fmt.Errorf("extension %s service hook %s failed: %s", extension.Id, fullEventName, status.Message) } return nil } // ----- Dispatch Handlers ----- func (s *eventService) handleProjectHandlerStatus( extension *extensions.Extension, statusMessage *azdext.ProjectHandlerStatus, ) { fullEventName := fmt.Sprintf("%s.%s", extension.Id, statusMessage.EventName) if val, ok := s.projectEvents.Load(fullEventName); ok { ch := val.(chan *azdext.ProjectHandlerStatus) ch <- statusMessage } } func (s *eventService) handleServiceHandlerStatus( extension *extensions.Extension, statusMessage *azdext.ServiceHandlerStatus, ) { fullEventName := fmt.Sprintf("%s.%s.%s", extension.Id, statusMessage.ServiceName, statusMessage.EventName) if val, ok := s.serviceEvents.Load(fullEventName); ok { ch := val.(chan *azdext.ServiceHandlerStatus) ch <- statusMessage } } // createProjectConfig converts a project.ProjectConfig into the azdext.ProjectConfig wire format. func (s *eventService) createProjectConfig(proj *project.ProjectConfig) *azdext.ProjectConfig { resolver := noEnvResolver env, err := s.lazyEnv.GetValue() if err == nil && env != nil { resolver = env.Getenv } resourceGroupName, err := proj.ResourceGroupName.Envsubst(resolver) if err != nil { log.Printf("failed to envsubst resource group name: %v", err) } services := make(map[string]*azdext.ServiceConfig, len(proj.Services)) for i, svc := range proj.Services { services[i] = s.createServiceConfig(svc) } projectConfig := &azdext.ProjectConfig{ Name: proj.Name, ResourceGroupName: resourceGroupName, Path: proj.Path, Metadata: func() *azdext.ProjectMetadata { if proj.Metadata != nil { return &azdext.ProjectMetadata{Template: proj.Metadata.Template} } return nil }(), Infra: &azdext.InfraOptions{ Provider: string(proj.Infra.Provider), Path: proj.Infra.Path, Module: proj.Infra.Module, }, Services: services, } return projectConfig } // createServiceConfig converts a project.ServiceConfig into the azdext.ServiceConfig wire format. func (s *eventService) createServiceConfig(svc *project.ServiceConfig) *azdext.ServiceConfig { resolver := noEnvResolver env, err := s.lazyEnv.GetValue() if err == nil && env != nil { resolver = env.Getenv } resourceGroupName, err := svc.ResourceGroupName.Envsubst(resolver) if err != nil { log.Printf("failed to envsubst resource group name: %v", err) } resourceName, err := svc.ResourceName.Envsubst(resolver) if err != nil { log.Printf("failed to envsubst resource name: %v", err) } image, err := svc.Image.Envsubst(resolver) if err != nil { log.Printf("failed to envsubst image: %v", err) } return &azdext.ServiceConfig{ Name: svc.Name, ResourceGroupName: resourceGroupName, ResourceName: resourceName, ApiVersion: svc.ApiVersion, RelativePath: svc.RelativePath, Host: string(svc.Host), Language: string(svc.Language), OutputPath: svc.OutputPath, Image: image, } } // syncExtensionOutput displays the extension output in the preview experience. // defer the returned function to stop the previewer when the function exits. func (s *eventService) syncExtensionOutput( ctx context.Context, extension *extensions.Extension, previewTitle string, ) func() { // Display the extension output in the preview experience previewOptions := &input.ShowPreviewerOptions{ Prefix: " ", Title: previewTitle, MaxLineCount: 8, } // This gets the multi-writer used by the extension and adds the preview writer to it. // Any output from stdout on the extension will be shown in the preview window. extOut := extension.StdOut() previewWriter := s.console.ShowPreviewer(ctx, previewOptions) extOut.AddWriter(previewWriter) // Stop the previewer when the function exits. return func() { s.console.StopPreviewer(ctx, false) extOut.RemoveWriter(previewWriter) } }