internal/search/grpc.go (274 lines of code) (raw):

// Package search provides gRPC client functionality for distributed code search using Zoekt. // It handles establishing and managing gRPC connections to Zoekt search backends, // constructing search queries with repository filtering, streaming search results, // and converting between Zoekt's gRPC protocol buffers and native data structures. // The package supports forwarding HTTP headers, endpoint parsing, and connection pooling // to efficiently distribute search requests across multiple Zoekt instances. package search import ( "bytes" "context" "crypto/tls" "crypto/x509" "errors" "fmt" "io" "log/slog" "strconv" "strings" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" "github.com/RoaringBitmap/roaring/v2" "github.com/sourcegraph/zoekt" zproto "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1" zquery "github.com/sourcegraph/zoekt/query" ) type SearchStream interface { Recv() (*zproto.StreamSearchResponse, error) } func withForwardedHeaders(ctx context.Context, r *SearchRequest) context.Context { for key, value := range r.Headers { ctx = metadata.AppendToOutgoingContext(ctx, strings.ToLower(key), value) } return ctx } func (searcher *Searcher) DoSearch(ctx context.Context, r *SearchRequest, c *Conn) ZoektResponse { conn, err := searcher.getGrpcConn(c.Endpoint) if err != nil { return ZoektResponse{ Error: fmt.Errorf("failed to create gRPC connection: %w", err), Endpoint: c.Endpoint, } } client := zproto.NewWebserverServiceClient(conn) ctx = withForwardedHeaders(ctx, r) endpointInfo, parseErr := parseEndpoint(c.Endpoint) if parseErr != nil { return ZoektResponse{ Error: fmt.Errorf("failed to parse endpoint: %w", parseErr), Endpoint: c.Endpoint, } } if endpointInfo.Path != "" { ctx = metadata.AppendToOutgoingContext(ctx, "x-forwarded-path", endpointInfo.Path) } textQuery, err := zquery.Parse(r.Query) if err != nil { return ZoektResponse{Error: fmt.Errorf("query parse failed: %w", err)} } var children []*zproto.Q children = append(children, zquery.QToProto(textQuery)) bitmap := roaring.NewBitmap() for _, id := range c.RepoIds { bitmap.Add(id) } var buf bytes.Buffer if _, writeErr := bitmap.WriteTo(&buf); writeErr != nil { return ZoektResponse{ Error: fmt.Errorf("failed to serialize repo IDs: %w", writeErr), Endpoint: c.Endpoint, } } children = append(children, &zproto.Q{ Query: &zproto.Q_RepoIds{ RepoIds: &zproto.RepoIds{ Repos: buf.Bytes(), }, }, }) finalQuery := &zproto.Q{ Query: &zproto.Q_And{ And: &zproto.And{Children: children}, }, } searchReq := &zproto.SearchRequest{ Query: finalQuery, Opts: &zproto.SearchOptions{ NumContextLines: int64(r.Options.NumContextLines), TotalMaxMatchCount: int64(r.Options.TotalMaxMatchCount), }, } stream, err := client.StreamSearch(ctx, &zproto.StreamSearchRequest{Request: searchReq}) if err != nil { return ZoektResponse{Error: err} } return handleGrpcSearchStream(ctx, stream, c.Endpoint) } func handleGrpcSearchStream(ctx context.Context, stream SearchStream, endpoint string) ZoektResponse { var matches []*zproto.FileMatch var stats zproto.Stats for { resp, err := stream.Recv() if ctxErr := ctx.Err(); ctxErr != nil { return ZoektResponse{ Error: ctxErr, Endpoint: endpoint, } } if err != nil { if errors.Is(err, io.EOF) { return ZoektResponse{ Endpoint: endpoint, Result: zoekt.SearchResult{ Stats: zoekt.Stats{ FileCount: int(stats.FileCount), MatchCount: int(stats.MatchCount), NgramMatches: int(stats.NgramMatches), }, Files: convertGrpcFiles(matches), }, } } return ZoektResponse{ Error: fmt.Errorf("error receiving from upstream stream: %w", err), Endpoint: endpoint, } } if resp == nil { return ZoektResponse{ Error: fmt.Errorf("received nil response from stream without EOF"), Endpoint: endpoint, } } chunk := resp.GetResponseChunk() if chunk == nil { continue } matches = append(matches, chunk.Files...) if chunk.Stats != nil { stats.FileCount += chunk.Stats.FileCount stats.MatchCount += chunk.Stats.MatchCount stats.NgramMatches += chunk.Stats.NgramMatches } if stats.NgramMatches == 0 { var ngramMatchCount int for _, file := range matches { for _, lm := range file.LineMatches { ngramMatchCount += len(lm.LineFragments) } } stats.NgramMatches = int64(ngramMatchCount) } } } func (searcher *Searcher) getGrpcConn(endpoint string) (*grpc.ClientConn, error) { endpointInfo, parseErr := parseEndpoint(endpoint) if parseErr != nil { return nil, fmt.Errorf("failed to parse endpoint: %w", parseErr) } searcher.grpcMutex.Lock() defer searcher.grpcMutex.Unlock() if searcher.GrpcConns == nil { searcher.GrpcConns = make(map[string]*grpc.ClientConn) } if conn, ok := searcher.GrpcConns[endpointInfo.HostPort]; ok { return conn, nil } var opts []grpc.DialOption if endpointInfo.Protocol == "https" { // Setup TLS credentials with system certificates tlsConfig := &tls.Config{ ServerName: searcher.TLSServerName, MinVersion: tls.VersionTLS12, } // If no server name is explicitly set, use the host from the endpoint if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify { // Extract hostname from HostPort (remove port if present) hostParts := strings.Split(endpointInfo.HostPort, ":") tlsConfig.ServerName = hostParts[0] } // Use system certificates systemPool, err := x509.SystemCertPool() if err != nil { slog.Warn("could not load system certificate pool", "error", err) } else if systemPool != nil { tlsConfig.RootCAs = systemPool } creds := credentials.NewTLS(tlsConfig) opts = append(opts, grpc.WithTransportCredentials(creds)) } else { // Use insecure connection opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) } // Create connection only using host:port newConn, err := grpc.NewClient(endpointInfo.HostPort, opts...) if err != nil { return nil, err } searcher.GrpcConns[endpointInfo.HostPort] = newConn return newConn, nil } func convertGrpcFiles(grpcFiles []*zproto.FileMatch) []zoekt.FileMatch { files := make([]zoekt.FileMatch, len(grpcFiles)) for i, gf := range grpcFiles { var repoID uint32 if parsed, err := strconv.ParseUint(gf.Repository, 10, 64); err == nil && parsed <= uint64(^uint32(0)) { repoID = uint32(parsed) } files[i] = zoekt.FileMatch{ FileName: string(gf.FileName), Repository: gf.Repository, Branches: gf.Branches, LineMatches: convertGrpcLineMatches(gf.LineMatches), Checksum: gf.Checksum, RepositoryID: repoID, Score: gf.Score, Language: gf.Language, Version: gf.Version, } } return files } func convertGrpcLineMatches(grpcLines []*zproto.LineMatch) []zoekt.LineMatch { lines := make([]zoekt.LineMatch, len(grpcLines)) for i, gl := range grpcLines { lines[i] = zoekt.LineMatch{ Line: gl.Line, Before: gl.Before, After: gl.After, LineStart: int(gl.LineStart), LineEnd: int(gl.LineEnd), LineNumber: int(gl.LineNumber), FileName: gl.FileName, Score: gl.Score, DebugScore: gl.DebugScore, LineFragments: convertGrpcLineFragments(gl.LineFragments), } } return lines } func convertGrpcLineFragments(fragments []*zproto.LineFragmentMatch) []zoekt.LineFragmentMatch { out := make([]zoekt.LineFragmentMatch, len(fragments)) for i, f := range fragments { out[i] = zoekt.LineFragmentMatch{ LineOffset: int(f.LineOffset), Offset: f.Offset, MatchLength: int(f.MatchLength), } } return out } func (searcher *Searcher) Close() error { searcher.grpcMutex.Lock() defer searcher.grpcMutex.Unlock() var errs []error for endpoint, conn := range searcher.GrpcConns { if err := conn.Close(); err != nil { errs = append(errs, fmt.Errorf("error closing gRPC connection to %s: %w", endpoint, err)) } } searcher.GrpcConns = make(map[string]*grpc.ClientConn) if len(errs) > 0 { return errors.Join(errs...) } return nil } func parseEndpoint(endpoint string) (EndpointInfo, error) { // Store the original endpoint info := EndpointInfo{ Original: endpoint, } // Determine protocol if strings.HasPrefix(endpoint, "https://") { info.Protocol = "https" } else if strings.HasPrefix(endpoint, "http://") { info.Protocol = "http" } else { return info, errors.New("unknown protocol") } endpoint = strings.TrimPrefix(endpoint, info.Protocol+"://") // Split into host:port and path parts parts := strings.SplitN(endpoint, "/", 2) info.HostPort = parts[0] if len(parts) > 1 { info.Path = "/" + parts[1] } return info, nil }