internal/acs/testserver/testserver.go (167 lines of code) (raw):

// Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package testserver provides helper methods for functional testing // against test ACS server. It provides facilities to make server send messages // to an agent and intercept messages from agent. // Note that this package is not meant for production code and can only be used // for testing. package testserver import ( "context" "errors" "fmt" "io" "net" "net/http" "net/http/httptest" "net/url" "os" "sync" acpb "github.com/GoogleCloudPlatform/agentcommunication_client/gapic/agentcommunicationpb" "github.com/GoogleCloudPlatform/galog" "github.com/GoogleCloudPlatform/google-guest-agent/internal/cfg" "google.golang.org/grpc" "google.golang.org/protobuf/proto" apb "google.golang.org/protobuf/types/known/anypb" ) const ( // rawToken is a dummy token used for testing. rawToken = "eyJhbGciOiIiLCJ0eXAiOiIifQ.eyJpc3MiOiIiLCJhdWQiOiIiLCJleHAiOjEyNTc4OTc1OTAsImlhdCI6MTI1Nzg5Mzk5MH0.P1kofb3I0Eaxd6xAWI0mLrfR2k48sIU9K_iWpwXQIX66Cd95dXtkGJ8JQ74KIWHK_HSYB7i7kSbukDl6VjDc1HrZlRtM8pVNbIv0lHyDe8FZgvW2w33964hk96I0M2NcSLyj6jO42yvWEs0VFJwoAuWtX9jXUqb7vlQf-ElmUXbx5jsKvMqjS6KtT44wQzUg9MjsOTfU9AEKhn-p0liNb-QJxG2Z0NzGI6dCfKchd-mXgpnn0r_2OAZ0aCICNu50ye74hfPCkEpTK5w4PWDoLNhWhJabBSoM4umct49G3nZ5jO1Auh50QaprskS_c82ZzgttNvNzv3NShHAAODCI8w" ) // acsImplementation struct holds all messages and address ACS server is running on. type acsImplementation struct { addr string // mu protects [agentSentMsgs] concurrent read/writes. mu sync.Mutex // agentSentMsgs are the messages sent by agent using [agentcomm.Send()]. // Essentially this represents the messages for the server. agentSentMsgs []*acpb.StreamAgentMessagesRequest // toSend is the channel holding messages temporarily that are sent down to // to an agent. Messages are received by agent on [agentcomm.Watch()]. toSend chan *acpb.StreamAgentMessagesResponse // recvErr captures errors received on the stream to report it back to the caller. recvErr chan error } // Server is a test server instance for ACS and MDS (required component by ACS). type Server struct { acshandler *acsImplementation // mds is the test MDS HTTP server instance. mds *httptest.Server // acs is the test ACS grpc server instance. acs *grpc.Server } // NewTestServer returns a new [Server] instance. func NewTestServer(addr string) *Server { return &Server{ acshandler: &acsImplementation{ addr: addr, toSend: make(chan *acpb.StreamAgentMessagesResponse), recvErr: make(chan error), }, } } func (s *acsImplementation) add(msg *acpb.StreamAgentMessagesRequest) { s.mu.Lock() defer s.mu.Unlock() s.agentSentMsgs = append(s.agentSentMsgs, msg) } func (s *acsImplementation) SendAgentMessage(context.Context, *acpb.SendAgentMessageRequest) (*acpb.SendAgentMessageResponse, error) { return nil, nil } // StreamAgentMessages implements the test ACS communication RPC. func (s *acsImplementation) StreamAgentMessages(stream acpb.AgentCommunication_StreamAgentMessagesServer) error { closed := make(chan struct{}) defer close(closed) go func() { for { rec, err := stream.Recv() select { case <-closed: return default: } if err != nil { if errors.Is(err, io.EOF) { s.recvErr <- nil return } s.recvErr <- err return } switch rec.GetType().(type) { case *acpb.StreamAgentMessagesRequest_MessageResponse: // Ignore ack's for test messages generated to send on Watch(). continue case *acpb.StreamAgentMessagesRequest_MessageBody: // Collect all messages sent by agent. s.add(rec) } // Acks as if service will ack on receiving msg from agent Send(). if err := stream.Send(&acpb.StreamAgentMessagesResponse{MessageId: rec.GetMessageId(), Type: &acpb.StreamAgentMessagesResponse_MessageResponse{}}); err != nil { s.recvErr <- err return } } }() for { select { case msg := <-s.toSend: if err := stream.Send(msg); err != nil { galog.Errorf("[TestACSServer] error sending message [%+v]: %v", msg, err) return err } case err := <-s.recvErr: galog.Errorf("[TestACSServer] received error on error stream: %v", err) return err } } } // startMDS starts the test MDS server and sets the environment variable // [GCE_METADATA_HOST] with its address. This address is used by the // [Metadata library] for MDS calls. [ACS Client library] depends on this // metadata library. // // [Metadata library]: https://cloud.google.com/go/compute/metadata // [ACS Client library]: https://github.com/GoogleCloudPlatform/agentcommunication_client func (s *Server) startMDS() error { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/computeMetadata/v1/instance/zone": fmt.Fprint(w, "test-zone") case "/computeMetadata/v1/project/numeric-project-id": fmt.Fprint(w, "test-project") case "/computeMetadata/v1/instance/id": fmt.Fprint(w, "test-instance") case "/computeMetadata/v1/instance/service-accounts/default/identity": fmt.Fprint(w, rawToken) } })) url, err := url.Parse(ts.URL) if err != nil { return fmt.Errorf("url.Parse(%s) failed with error: %w", ts.URL, err) } addr := url.Host if err := os.Setenv("GCE_METADATA_HOST", addr); err != nil { return fmt.Errorf("os.Setenv(GCE_METADATA_HOST, %s) failed with error: %w", addr, err) } s.mds = ts return nil } // startACSServer starts the ACS grpc server on a UDS address configured and // sets the environment variable [GCE_ACS_HOST] with its UDS address. // Guest Agent when creating a new ACS connection looks if this environment // variable is set otherwise uses the default prod server. func (s *Server) startACSServer() error { addr := s.acshandler.addr lis, err := net.Listen("unix", addr) if err != nil { return fmt.Errorf("unable to start listener on %s: %w", addr, err) } srv := grpc.NewServer() acpb.RegisterAgentCommunicationServer(srv, s.acshandler) go func() { if err := srv.Serve(lis); err != nil { galog.Debugf("Server stop serving on %q: %v", s.acshandler.addr, err) } }() cfg.Retrieve().ACS.Host = addr s.acs = srv return nil } // Start starts the test server. func (s *Server) Start() error { if err := s.startMDS(); err != nil { return fmt.Errorf("unable to start test MDS: %w", err) } return s.startACSServer() } // CleanUp unsets all environment variable set during setup and stops all // running servers. func (s *Server) CleanUp() { os.Unsetenv("GCE_METADATA_HOST") cfg.Retrieve().ACS.Host = "" s.mds.Close() s.acs.Stop() } // SendToAgent sends this message to agent, agent receives them on [acs.Watch()]. // Note that labels must include the [message_type] field like for e.g. // [agent_controlplane.GetOSInfo]. See ACS [handler] for all known message types. func (s *Server) SendToAgent(msg proto.Message, labels map[string]string) error { msgBytes, err := proto.Marshal(msg) if err != nil { return fmt.Errorf("proto.Marshal(%+v) failed with error: %v", msg, err) } body := &acpb.MessageBody{Body: &apb.Any{Value: msgBytes, TypeUrl: string(proto.MessageName(msg))}, Labels: labels} pb := &acpb.StreamAgentMessagesResponse{MessageId: "test-message-id", Type: &acpb.StreamAgentMessagesResponse_MessageBody{MessageBody: body}} s.acshandler.toSend <- pb return nil } // AgentSentMessages returns the messages sent by agent by calling [acs.Send()] // for server. func (s *Server) AgentSentMessages() []*apb.Any { s.acshandler.mu.Lock() defer s.acshandler.mu.Unlock() var msgs []*apb.Any for _, msg := range s.acshandler.agentSentMsgs { msgs = append(msgs, msg.GetMessageBody().GetBody()) } return msgs }