internal/search/search_test_helpers.go (118 lines of code) (raw):
package search
import (
"context"
"fmt"
"io"
"net"
"strings"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"
"github.com/sourcegraph/zoekt"
proto "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1"
)
// StartBufconnServer starts an in-memory gRPC server on a bufconn listener and
// registers a MockZoektGRPCServer on it that is configured with files and stats.
//
// It returns a client connection whose dialer is wired directly to that
// listener, together with a synthetic "http://bufnet-<ptr>" endpoint string that
// is unique per listener. The test fails immediately if the client cannot be
// created. The server is stopped and the connection closed automatically when
// the test (and its subtests) finish, so callers do not need to clean up.
func StartBufconnServer(t *testing.T, files []*proto.FileMatch, stats *proto.Stats) (*grpc.ClientConn, string) {
	t.Helper()

	lis := bufconn.Listen(1024 * 1024)
	s := grpc.NewServer()
	proto.RegisterWebserverServiceServer(s, &MockZoektGRPCServer{
		files: files,
		stats: stats,
	})
	go func() {
		// Serve only returns once the server is stopped; its error is of no
		// interest to the tests using this helper.
		_ = s.Serve(lis)
	}()
	// Registered before the client is created so the server goroutine is torn
	// down even when the require below aborts the test.
	t.Cleanup(s.Stop)

	dialer := func(context.Context, string) (net.Conn, error) {
		return lis.Dial()
	}

	// The listener's address makes the name unique. The host is what the client
	// dials; the endpoint is the same name with an http:// prefix so that it is
	// compatible with parseEndpoint.
	host := fmt.Sprintf("bufnet-%p", lis)
	endpoint := "http://" + host

	conn, err := NewTestGrpcClient(context.Background(), host, dialer)
	require.NoError(t, err)
	// Closing a connection the caller has already closed only yields an error,
	// which is deliberately ignored here.
	t.Cleanup(func() { _ = conn.Close() })

	return conn, endpoint
}
// MockZoektGRPCServer is a test double for the Zoekt webserver gRPC service.
// Its StreamSearch method either sends one canned response built from files and
// stats, or returns an error built from errMsg, depending on fail.
type MockZoektGRPCServer struct {
	// NOTE(review): presumably supplies default implementations for the service
	// methods this mock does not define itself; confirm against the generated code.
	proto.UnimplementedWebserverServiceServer
	// files is sent as the Files of the StreamSearch response.
	files []*proto.FileMatch
	// stats is sent as the Stats of the StreamSearch response.
	stats *proto.Stats
	// fail makes StreamSearch return an error instead of sending a response.
	fail bool
	// errMsg is the text of the error returned when fail is set.
	errMsg string
}
// StreamSearch is the mock's handler for the StreamSearch RPC. The request is
// ignored. Unless the mock is configured to fail, a single response chunk
// carrying the configured files and stats is sent on srv and the result of that
// send is returned; otherwise an error whose text is errMsg is returned and
// nothing is sent.
func (m *MockZoektGRPCServer) StreamSearch(req *proto.StreamSearchRequest, srv proto.WebserverService_StreamSearchServer) error {
	if !m.fail {
		chunk := &proto.SearchResponse{
			Files: m.files,
			Stats: m.stats,
		}
		return srv.Send(&proto.StreamSearchResponse{ResponseChunk: chunk})
	}
	return fmt.Errorf("%s", m.errMsg)
}
// ToProtoFileMatches converts zoekt file matches into their protobuf
// counterparts. Only FileName and Score are carried over; every other field of
// the resulting messages is left at its zero value. The returned slice is never
// nil and has one entry per input match, in the same order.
func ToProtoFileMatches(matches []zoekt.FileMatch) []*proto.FileMatch {
	converted := make([]*proto.FileMatch, len(matches))
	for i := range matches {
		src := &matches[i]
		converted[i] = &proto.FileMatch{
			FileName: []byte(src.FileName),
			Score:    src.Score,
		}
	}
	return converted
}
// StartFailingBufconnServer starts an in-memory gRPC server on a bufconn
// listener and registers a MockZoektGRPCServer on it that is configured to fail
// with errMsg.
//
// It returns a client connection whose dialer is wired directly to that
// listener, together with a synthetic "http://bufnet-<ptr>" endpoint string that
// is unique per listener. The test fails immediately if the client cannot be
// created. The server is stopped and the connection closed automatically when
// the test (and its subtests) finish, so callers do not need to clean up.
func StartFailingBufconnServer(t *testing.T, errMsg string) (*grpc.ClientConn, string) {
	t.Helper()

	lis := bufconn.Listen(1024 * 1024)
	s := grpc.NewServer()
	proto.RegisterWebserverServiceServer(s, &MockZoektGRPCServer{
		fail:   true,
		errMsg: errMsg,
	})
	go func() {
		// Serve only returns once the server is stopped; its error is of no
		// interest to the tests using this helper.
		_ = s.Serve(lis)
	}()
	// Registered before the client is created so the server goroutine is torn
	// down even when the require below aborts the test.
	t.Cleanup(s.Stop)

	dialer := func(context.Context, string) (net.Conn, error) {
		return lis.Dial()
	}

	// The listener's address makes the name unique. The host is what the client
	// dials; the endpoint is the same name with an http:// prefix so that it is
	// compatible with parseEndpoint.
	host := fmt.Sprintf("bufnet-%p", lis)
	endpoint := "http://" + host

	conn, err := NewTestGrpcClient(context.Background(), host, dialer)
	require.NoError(t, err)
	// Closing a connection the caller has already closed only yields an error,
	// which is deliberately ignored here.
	t.Cleanup(func() { _ = conn.Close() })

	return conn, endpoint
}
// mockStream is an in-memory source of canned StreamSearchResponse messages
// whose Recv method replays them in order.
// NOTE(review): presumably stands in for the receiving side of a StreamSearch
// stream in tests; confirm against its users.
type mockStream struct {
	// responses are the messages Recv hands out, one per call, in order.
	responses []*proto.StreamSearchResponse
	// err, when non-nil, is returned by every Recv call instead of a message.
	err error
	// index is the position in responses of the next message Recv returns.
	index int
}
// Recv returns the next canned response and advances the stream. If the stream
// was configured with an error, that error is returned on every call and no
// response is consumed. Once all responses have been handed out, Recv returns
// io.EOF.
func (m *mockStream) Recv() (*proto.StreamSearchResponse, error) {
	switch {
	case m.err != nil:
		return nil, m.err
	case m.index >= len(m.responses):
		return nil, io.EOF
	}

	next := m.responses[m.index]
	m.index++
	return next, nil
}
// NewTestGrpcClient creates a gRPC client connection for target that connects
// through dialer and uses insecure (plaintext) transport credentials.
//
// target is prefixed with "passthrough:///" unless it already carries that
// prefix. ctx is accepted for the signature's sake but is not used by this
// function.
func NewTestGrpcClient(ctx context.Context, target string, dialer func(context.Context, string) (net.Conn, error)) (*grpc.ClientConn, error) {
	const scheme = "passthrough:///"
	if !strings.HasPrefix(target, scheme) {
		target = scheme + target
	}

	opts := []grpc.DialOption{
		grpc.WithContextDialer(dialer),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	}
	return grpc.NewClient(target, opts...)
}
// GetHostFromEndpoint returns the host part of an endpoint URL: a single
// leading "http://" or "https://" is dropped if present, and the result is
// everything up to (but not including) the first "/" that follows. Any port is
// kept as part of the host. Other schemes are not recognised and are treated as
// ordinary text.
func GetHostFromEndpoint(endpoint string) string {
	rest := endpoint
	switch {
	case strings.HasPrefix(endpoint, "http://"):
		rest = endpoint[len("http://"):]
	case strings.HasPrefix(endpoint, "https://"):
		rest = endpoint[len("https://"):]
	}

	// Everything before the first slash is the host; with no slash the whole
	// remainder is.
	host, _, _ := strings.Cut(rest, "/")
	return host
}