pkg/client/mock/stub_serverV2.go (292 lines of code) (raw):
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License;
// you may not use this file except in compliance with the Elastic License.
package mock
import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"net"
"os"
"path/filepath"
"runtime"
"sync"
"github.com/gofrs/uuid/v5"
"google.golang.org/grpc"
"github.com/elastic/elastic-agent-client/v7/pkg/client/chunk"
"github.com/elastic/elastic-agent-client/v7/pkg/proto"
)
// StubServerCheckinV2 is the checkin callback for the V2 controller.
//
// CheckinV2 calls it once for every observed message it receives from the
// client and sends the returned expected message back on the stream.
// Returning nil makes CheckinV2 return nil, which ends the stream.
type StubServerCheckinV2 func(observed *proto.CheckinObserved) *proto.CheckinExpected
// StubServerArtifactFetch is the artifact fetch callback for the V2 controller.
//
// Fetch passes it the incoming request together with the server side of the
// stream, and returns its error unchanged as the result of the RPC.
type StubServerArtifactFetch func(request *proto.ArtifactFetchRequest, server proto.ElasticAgentArtifact_FetchServer) error
// StubServerLog is the logging callback for the V2 controller.
//
// Log passes it the incoming log message request and returns its response and
// error unchanged as the result of the RPC.
type StubServerLog func(fetch *proto.LogMessageRequest) (*proto.LogMessageResponse, error)
// StubServerStore is the interface that backs the store RPCs of StubServerV2.
//
// Each method is called by the StubServerV2 RPC handler of the same name. The
// handler drops the request context and returns the method's response and
// error unchanged.
type StubServerStore interface {
// BeginTx receives the request passed to StubServerV2.BeginTx.
BeginTx(request *proto.StoreBeginTxRequest) (*proto.StoreBeginTxResponse, error)
// GetKey receives the request passed to StubServerV2.GetKey.
GetKey(request *proto.StoreGetKeyRequest) (*proto.StoreGetKeyResponse, error)
// SetKey receives the request passed to StubServerV2.SetKey.
SetKey(request *proto.StoreSetKeyRequest) (*proto.StoreSetKeyResponse, error)
// DeleteKey receives the request passed to StubServerV2.DeleteKey.
DeleteKey(request *proto.StoreDeleteKeyRequest) (*proto.StoreDeleteKeyResponse, error)
// CommitTx receives the request passed to StubServerV2.CommitTx.
CommitTx(request *proto.StoreCommitTxRequest) (*proto.StoreCommitTxResponse, error)
// DiscardTx receives the request passed to StubServerV2.DiscardTx.
DiscardTx(request *proto.StoreDiscardTxRequest) (*proto.StoreDiscardTxResponse, error)
}
// StubServerV2 is the mocked server implementation for the V2 controller.
//
// It is registered by Start for the ElasticAgent, store, artifact and log gRPC
// services; every RPC handler forwards to one of the *Impl callbacks below.
// A callback that is left nil panics when its RPC is invoked.
type StubServerV2 struct {
proto.UnimplementedElasticAgentServer
proto.UnimplementedElasticAgentStoreServer
proto.UnimplementedElasticAgentArtifactServer
proto.UnimplementedElasticAgentLogServer
// LocalRPC is the unix socket or windows named pipe name.
//
// When it is non-empty, listen derives the platform specific endpoint from it:
// \\.\pipe\<LocalRPC> on windows, and <LocalRPC>.sock inside a randomly named
// directory under os.TempDir() elsewhere. When it is empty the server listens
// on TCP instead (see Port).
LocalRPC string
// Port for TCP RPC only. After a successful listen it is overwritten with the
// port that was actually bound.
Port int
// CheckinImpl is called by Checkin for every received message.
CheckinImpl StubServerCheckin
// ActionImpl is called by Actions for every received action response.
ActionImpl StubServerAction
// CheckinV2Impl is called by CheckinV2 for every received observed message.
CheckinV2Impl StubServerCheckinV2
// StoreImpl backs BeginTx, GetKey, SetKey, DeleteKey, CommitTx and DiscardTx.
StoreImpl StubServerStore
// ArtifactFetchImpl is called by Fetch.
ArtifactFetchImpl StubServerArtifactFetch
// LogImpl is called by Log.
LogImpl StubServerLog
// server is the running gRPC server; set by Start and cleared by Stop.
server *grpc.Server
// ActionsChan queues actions for the client: PerformAction and
// PerformDiagnostic send to it and Actions receives from it.
ActionsChan chan *PerformAction
// SentActions holds the actions sent by Actions, keyed by the generated
// request id, until the matching response arrives. Actions writes to it, so it
// must be non-nil before any action is sent.
SentActions map[string]*PerformAction
// target for gRPC, set by listen and returned by GetTarget
target string
}
// listen opens the listener the gRPC server is served on, over tcp or a local
// unix socket or named windows pipe, and records the matching client target
// in s.target.
//
// When s.LocalRPC is empty it listens on TCP port s.Port (host left empty) and
// stores the port that was actually bound back into s.Port. Otherwise it
// listens on the windows named pipe \\.\pipe\<LocalRPC>, or on other platforms
// on the unix socket <LocalRPC>.sock placed in a newly created, randomly named
// directory under os.TempDir().
//
// On success cleanup is never nil: it removes the socket directory for the
// unix transport and is a no-op for the others. When creating the socket
// directory or opening the TCP/unix listener fails, the listener and cleanup
// are both nil and any directory created here has already been removed.
//
// opt is accepted but not used.
func (s *StubServerV2) listen(opt ...grpc.ServerOption) (lis net.Listener, cleanup func() error, err error) {
	cleanup = func() error { return nil }

	if s.LocalRPC == "" {
		// Assign to the named results instead of shadowing them with :=.
		lis, err = net.Listen("tcp", fmt.Sprintf(":%d", s.Port))
		if err != nil {
			return nil, nil, err
		}
		// With Port == 0 the OS picks a free port; publish the real one.
		s.Port = lis.Addr().(*net.TCPAddr).Port
		s.target = lis.Addr().String()
		return lis, cleanup, nil
	}

	if runtime.GOOS == "windows" {
		s.target = fmt.Sprintf("\\\\.\\pipe\\%s", s.LocalRPC)
		lis, err = newNPipeListener(s.target, "")
		return lis, cleanup, err
	}

	// A short random directory name keeps the socket path short.
	socketDir := filepath.Join(os.TempDir(), randomString(3))
	if err = os.MkdirAll(socketDir, 0750); err != nil {
		return nil, nil, err
	}

	rpcPath := filepath.Join(socketDir, s.LocalRPC+".sock")
	s.target = fmt.Sprintf("unix://%s", rpcPath)
	lis, err = net.Listen("unix", rpcPath)
	if err != nil {
		// Best effort: the listen error is the one worth reporting.
		_ = os.RemoveAll(socketDir)
		return nil, nil, err
	}
	return lis, func() error { return os.RemoveAll(socketDir) }, nil
}
// randomString returns the lowercase hex encoding of length bytes read from
// crypto/rand, so the result is 2*length characters long. It panics if the
// random source returns an error.
func randomString(length int) string {
	buf := make([]byte, length)
	if _, err := rand.Read(buf); err != nil {
		panic(err)
	}
	return hex.EncodeToString(buf)
}
// GetTarget returns the full gRPC target recorded when the listener was
// opened, including the prefix for local transport. Its form depends on the
// current platform and the s.LocalRPC value; it is empty until the server has
// started listening.
func (s *StubServerV2) GetTarget() string {
	target := s.target
	return target
}
// Start the mock server.
//
// Start opens the listener selected by LocalRPC/Port, creates a gRPC server
// with opt, registers s for the ElasticAgent, store, artifact and log
// services, and serves on a background goroutine. It returns the error from
// opening the listener, if any; errors from serving are discarded.
func (s *StubServerV2) Start(opt ...grpc.ServerOption) error {
	lis, cleanup, err := s.listen()
	if err != nil {
		return err
	}

	srv := grpc.NewServer(opt...)
	s.server = srv
	proto.RegisterElasticAgentServer(s.server, s)
	proto.RegisterElasticAgentStoreServer(s.server, s)
	proto.RegisterElasticAgentArtifactServer(s.server, s)
	proto.RegisterElasticAgentLogServer(s.server, s)

	// The goroutine uses the local srv, not s.server, because Stop sets
	// s.server to nil.
	go func() {
		// Registered before the blocking Serve call so the listener cleanup
		// runs on every exit from this goroutine, including a panic in Serve.
		// The cleanup error has nowhere to be reported.
		defer func() { _ = cleanup() }()
		// Serve returns once the server is stopped; its error is not
		// reported by this mock.
		_ = srv.Serve(lis)
	}()
	return nil
}
// Stop the mock server.
//
// Stop is a no-op when no server is set; otherwise it stops the gRPC server
// and forgets it, so a second call does nothing.
func (s *StubServerV2) Stop() {
	srv := s.server
	if srv == nil {
		return
	}
	srv.Stop()
	s.server = nil
}
// Checkin is the checkin implementation for the mock server.
//
// It loops over the stream: each received message is handed to s.CheckinImpl
// and the value it returns is sent back. A nil return from CheckinImpl ends
// the stream without error; a receive or send error is returned as is.
func (s *StubServerV2) Checkin(server proto.ElasticAgent_CheckinServer) error {
	for {
		observed, recvErr := server.Recv()
		if recvErr != nil {
			return recvErr
		}

		expected := s.CheckinImpl(observed)
		if expected == nil {
			// returning closes the connection to the client
			return nil
		}

		if sendErr := server.Send(expected); sendErr != nil {
			return sendErr
		}
	}
}
// CheckinV2 is the V2 checkin implementation for the mock server.
//
// It loops over the stream: each observed message read through
// chunk.RecvObserved is handed to s.CheckinV2Impl and the expected message it
// returns is sent back. A nil return from CheckinV2Impl ends the stream
// without error; a receive or send error is returned as is.
func (s *StubServerV2) CheckinV2(server proto.ElasticAgent_CheckinV2Server) error {
	for {
		observed, recvErr := chunk.RecvObserved(server)
		if recvErr != nil {
			return recvErr
		}

		expected := s.CheckinV2Impl(observed)
		if expected == nil {
			// returning closes the connection to the client
			return nil
		}

		if sendErr := server.Send(expected); sendErr != nil {
			return sendErr
		}
	}
}
// Actions is the action implementation for the mock V2 server.
//
// A background goroutine takes actions from s.ActionsChan, records each one in
// s.SentActions under a newly generated UUID and sends it to the client; it
// stops when Actions returns and panics if sending fails.
//
// The calling goroutine reads responses from the stream. Every response is
// first passed to s.ActionImpl; an error from it ends the stream with a nil
// result. A response whose id is not in s.SentActions is ignored. Otherwise
// the action is removed from the map and its callback is invoked: Callback
// for CUSTOM actions, DiagCallback for DIAGNOSTICS actions. For a FAILED
// response the callback receives an error built from the "error" key of the
// JSON result ("unknown error" if absent). Any other action type panics.
//
// Actions returns the error from receiving on the stream or from decoding a
// response result. The mutex guarding s.SentActions is local to this call, so
// it only serialises the two goroutines of one stream.
func (s *StubServerV2) Actions(server proto.ElasticAgent_ActionsServer) error {
	var m sync.Mutex
	done := make(chan struct{})
	go func() {
		for {
			select {
			case <-done:
				return
			case act := <-s.ActionsChan:
				id := uuid.Must(uuid.NewV4()).String()
				// Record the action before sending so the response can
				// never arrive ahead of the bookkeeping.
				m.Lock()
				s.SentActions[id] = act
				m.Unlock()
				err := server.Send(&proto.ActionRequest{
					Type:     act.Type,
					Id:       id,
					Name:     act.Name,
					Params:   act.Params,
					UnitId:   act.UnitID,
					UnitType: act.UnitType,
					Level:    act.Level,
				})
				if err != nil {
					panic(err)
				}
			}
		}
	}()
	defer close(done)

	// failure builds the error reported to a callback for a FAILED response.
	// Indexing a nil result map is safe and yields "unknown error".
	failure := func(result map[string]interface{}) error {
		if msg, ok := result["error"]; ok {
			return fmt.Errorf("%s", msg)
		}
		return fmt.Errorf("unknown error")
	}

	for {
		response, err := server.Recv()
		if err != nil {
			return err
		}
		if err = s.ActionImpl(response); err != nil {
			// close connection to client
			return nil
		}

		m.Lock()
		action, ok := s.SentActions[response.Id]
		if ok {
			delete(s.SentActions, response.Id)
		}
		m.Unlock()
		if !ok {
			// nothing to do, unknown action
			continue
		}

		var result map[string]interface{}
		if response.Result != nil {
			if err = json.Unmarshal(response.Result, &result); err != nil {
				return err
			}
		}

		failed := response.Status == proto.ActionResponse_FAILED
		switch action.Type {
		case proto.ActionRequest_CUSTOM:
			if failed {
				action.Callback(nil, failure(result))
			} else {
				action.Callback(result, nil)
			}
		case proto.ActionRequest_DIAGNOSTICS:
			if failed {
				action.DiagCallback(nil, failure(result))
			} else {
				action.DiagCallback(response.Diagnostic, nil)
			}
		default:
			panic("unknown action type")
		}
	}
}
// PerformAction is the implementation for the V2 mock server.
//
// It JSON-encodes params, queues an action with the given unit and name on
// s.ActionsChan, and blocks until the action's Callback is invoked, returning
// the result and error handed to that callback. An encoding failure is
// returned before anything is queued. The action's Type is left at its zero
// value (presumably ActionRequest_CUSTOM; confirm against the proto
// definition).
func (s *StubServerV2) PerformAction(unitID string, unitType proto.UnitType, name string, params map[string]interface{}) (map[string]interface{}, error) {
	encoded, err := json.Marshal(params)
	if err != nil {
		return nil, err
	}

	// Unbuffered: the callback blocks until the receive below picks it up.
	results := make(chan actionResultCh)
	deliver := func(m map[string]interface{}, err error) {
		results <- actionResultCh{Result: m, Err: err}
	}

	request := &PerformAction{
		UnitID:   unitID,
		UnitType: unitType,
		Name:     name,
		Params:   encoded,
		Callback: deliver,
	}
	s.ActionsChan <- request

	outcome := <-results
	return outcome.Result, outcome.Err
}
// PerformDiagnostic is the implementation for the V2 mock server.
//
// It queues a DIAGNOSTICS action for the given unit, level and raw params on
// s.ActionsChan, and blocks until the action's DiagCallback is invoked,
// returning the diagnostic results and error handed to that callback.
func (s *StubServerV2) PerformDiagnostic(unitID string, unitType proto.UnitType, level proto.ActionRequest_Level, params []byte) ([]*proto.ActionDiagnosticUnitResult, error) {
	// Unbuffered: the callback blocks until the receive below picks it up.
	results := make(chan actionResultCh)
	deliver := func(diag []*proto.ActionDiagnosticUnitResult, err error) {
		results <- actionResultCh{Diag: diag, Err: err}
	}

	request := &PerformAction{
		Type:         proto.ActionRequest_DIAGNOSTICS,
		UnitID:       unitID,
		UnitType:     unitType,
		Level:        level,
		Params:       params,
		DiagCallback: deliver,
	}
	s.ActionsChan <- request

	outcome := <-results
	return outcome.Diag, outcome.Err
}
// BeginTx implementation for the V2 stub server.
//
// The request context is ignored; the request is delegated to s.StoreImpl and
// its response and error are returned unchanged.
func (s *StubServerV2) BeginTx(_ context.Context, request *proto.StoreBeginTxRequest) (*proto.StoreBeginTxResponse, error) {
	store := s.StoreImpl
	return store.BeginTx(request)
}
// GetKey implementation for the V2 stub server.
//
// The request context is ignored; the request is delegated to s.StoreImpl and
// its response and error are returned unchanged.
func (s *StubServerV2) GetKey(_ context.Context, request *proto.StoreGetKeyRequest) (*proto.StoreGetKeyResponse, error) {
	store := s.StoreImpl
	return store.GetKey(request)
}
// SetKey implementation for the V2 stub server.
//
// The request context is ignored; the request is delegated to s.StoreImpl and
// its response and error are returned unchanged.
func (s *StubServerV2) SetKey(_ context.Context, request *proto.StoreSetKeyRequest) (*proto.StoreSetKeyResponse, error) {
	store := s.StoreImpl
	return store.SetKey(request)
}
// DeleteKey implementation for the V2 stub server.
//
// The request context is ignored; the request is delegated to s.StoreImpl and
// its response and error are returned unchanged.
func (s *StubServerV2) DeleteKey(_ context.Context, request *proto.StoreDeleteKeyRequest) (*proto.StoreDeleteKeyResponse, error) {
	store := s.StoreImpl
	return store.DeleteKey(request)
}
// CommitTx implementation for the V2 stub server.
//
// The request context is ignored; the request is delegated to s.StoreImpl and
// its response and error are returned unchanged.
func (s *StubServerV2) CommitTx(_ context.Context, request *proto.StoreCommitTxRequest) (*proto.StoreCommitTxResponse, error) {
	store := s.StoreImpl
	return store.CommitTx(request)
}
// DiscardTx implementation for the V2 stub server.
//
// The request context is ignored; the request is delegated to s.StoreImpl and
// its response and error are returned unchanged.
func (s *StubServerV2) DiscardTx(_ context.Context, request *proto.StoreDiscardTxRequest) (*proto.StoreDiscardTxResponse, error) {
	store := s.StoreImpl
	return store.DiscardTx(request)
}
// Fetch implementation for the V2 stub server.
//
// The request and the server side of the stream are handed to
// s.ArtifactFetchImpl, whose error is returned unchanged.
func (s *StubServerV2) Fetch(request *proto.ArtifactFetchRequest, server proto.ElasticAgentArtifact_FetchServer) error {
	fetch := s.ArtifactFetchImpl
	return fetch(request, server)
}
// Log implementation for the V2 stub server.
//
// The request context is ignored; the request is handed to s.LogImpl and its
// response and error are returned unchanged.
func (s *StubServerV2) Log(_ context.Context, request *proto.LogMessageRequest) (*proto.LogMessageResponse, error) {
	logFn := s.LogImpl
	return logFn(request)
}