internal/search/search_test_helpers.go (118 lines of code) (raw):

package search import ( "context" "fmt" "io" "net" "strings" "testing" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" "github.com/sourcegraph/zoekt" proto "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1" ) func StartBufconnServer(t *testing.T, files []*proto.FileMatch, stats *proto.Stats) (*grpc.ClientConn, string) { lis := bufconn.Listen(1024 * 1024) s := grpc.NewServer() mock := &MockZoektGRPCServer{ files: files, stats: stats, } proto.RegisterWebserverServiceServer(s, mock) go func() { _ = s.Serve(lis) }() dialer := func(context.Context, string) (net.Conn, error) { return lis.Dial() } // Generate a unique name and add http:// prefix to make it compatible with parseEndpoint endpoint := fmt.Sprintf("http://bufnet-%p", lis) // Extract the host part to use as the target for the connection hostPart := strings.TrimPrefix(endpoint, "http://") hostPart = strings.Split(hostPart, "/")[0] conn, err := NewTestGrpcClient(context.Background(), hostPart, dialer) require.NoError(t, err) return conn, endpoint } type MockZoektGRPCServer struct { proto.UnimplementedWebserverServiceServer files []*proto.FileMatch stats *proto.Stats fail bool errMsg string } func (m *MockZoektGRPCServer) StreamSearch(req *proto.StreamSearchRequest, srv proto.WebserverService_StreamSearchServer) error { if m.fail { return fmt.Errorf("%s", m.errMsg) } return srv.Send(&proto.StreamSearchResponse{ ResponseChunk: &proto.SearchResponse{ Files: m.files, Stats: m.stats, }, }) } func ToProtoFileMatches(matches []zoekt.FileMatch) []*proto.FileMatch { out := make([]*proto.FileMatch, 0, len(matches)) for _, m := range matches { out = append(out, &proto.FileMatch{ FileName: []byte(m.FileName), Score: m.Score, }) } return out } func StartFailingBufconnServer(t *testing.T, errMsg string) (*grpc.ClientConn, string) { lis := bufconn.Listen(1024 * 1024) s := grpc.NewServer() mock := &MockZoektGRPCServer{ fail: true, errMsg: errMsg, } proto.RegisterWebserverServiceServer(s, mock) go func() { _ = s.Serve(lis) }() dialer := func(context.Context, string) (net.Conn, error) { return lis.Dial() } // Add http:// prefix to make it compatible with parseEndpoint endpoint := fmt.Sprintf("http://bufnet-%p", lis) // Extract the host part to use as the target for the connection hostPart := strings.TrimPrefix(endpoint, "http://") hostPart = strings.Split(hostPart, "/")[0] conn, err := NewTestGrpcClient(context.Background(), hostPart, dialer) require.NoError(t, err) return conn, endpoint } type mockStream struct { responses []*proto.StreamSearchResponse err error index int } func (m *mockStream) Recv() (*proto.StreamSearchResponse, error) { if m.err != nil { return nil, m.err } if m.index >= len(m.responses) { return nil, io.EOF } resp := m.responses[m.index] m.index++ return resp, nil } func NewTestGrpcClient(ctx context.Context, target string, dialer func(context.Context, string) (net.Conn, error)) (*grpc.ClientConn, error) { if !strings.HasPrefix(target, "passthrough:///") { target = "passthrough:///" + target } return grpc.NewClient(target, grpc.WithContextDialer(dialer), grpc.WithTransportCredentials(insecure.NewCredentials()), ) } // GetHostFromEndpoint is a helper function to extract host part from endpoint URL func GetHostFromEndpoint(endpoint string) string { // Remove protocol prefix if present cleanEndpoint := endpoint if strings.HasPrefix(endpoint, "http://") { cleanEndpoint = strings.TrimPrefix(endpoint, "http://") } else if strings.HasPrefix(endpoint, "https://") { cleanEndpoint = strings.TrimPrefix(endpoint, "https://") } // Extract host part (before any path) hostPart := strings.Split(cleanEndpoint, "/")[0] return hostPart }