internal/search/grpc.go (274 lines of code) (raw):
// Package search provides gRPC client functionality for distributed code search using Zoekt.
// It handles establishing and managing gRPC connections to Zoekt search backends,
// constructing search queries with repository filtering, streaming search results,
// and converting between Zoekt's gRPC protocol buffers and native data structures.
// The package supports forwarding HTTP headers, endpoint parsing, and connection pooling
// to efficiently distribute search requests across multiple Zoekt instances.
package search
import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net"
	"strconv"
	"strings"

	"github.com/RoaringBitmap/roaring/v2"
	"github.com/sourcegraph/zoekt"
	zproto "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1"
	zquery "github.com/sourcegraph/zoekt/query"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/metadata"
)
// SearchStream is the receive side of a streaming search call. It is the only
// part of the stream that handleGrpcSearchStream depends on, and it is
// satisfied by the stream that DoSearch obtains from StreamSearch.
type SearchStream interface {
// Recv returns the next response message from the stream, or an error.
// handleGrpcSearchStream treats io.EOF as the normal end of the stream.
Recv() (*zproto.StreamSearchResponse, error)
}
// withForwardedHeaders returns a context whose outgoing gRPC metadata
// additionally carries every entry of r.Headers, with each header name
// lower-cased. When r has no headers, ctx is returned unchanged.
func withForwardedHeaders(ctx context.Context, r *SearchRequest) context.Context {
	if len(r.Headers) == 0 {
		return ctx
	}

	// Gather all key/value pairs first so the metadata is appended in one call.
	pairs := make([]string, 0, 2*len(r.Headers))
	for name, val := range r.Headers {
		pairs = append(pairs, strings.ToLower(name), val)
	}
	return metadata.AppendToOutgoingContext(ctx, pairs...)
}
// DoSearch runs the search described by r against the Zoekt backend described
// by c and returns the aggregated result of the streamed response.
//
// The steps are:
//  1. obtain a (possibly cached) gRPC connection for c.Endpoint;
//  2. forward r.Headers, plus the endpoint's path as "x-forwarded-path" when it
//     has one, as outgoing gRPC metadata;
//  3. parse r.Query and AND it with a repository filter built from c.RepoIds
//     (serialized as a roaring bitmap);
//  4. open a StreamSearch call and collect its chunks.
//
// Every failure is reported through ZoektResponse.Error, and every returned
// response, successful or not, carries c.Endpoint so callers can tell which
// backend produced it.
func (searcher *Searcher) DoSearch(ctx context.Context, r *SearchRequest, c *Conn) ZoektResponse {
	// fail builds an error response that is always attributed to this endpoint.
	fail := func(err error) ZoektResponse {
		return ZoektResponse{Error: err, Endpoint: c.Endpoint}
	}

	conn, err := searcher.getGrpcConn(c.Endpoint)
	if err != nil {
		return fail(fmt.Errorf("failed to create gRPC connection: %w", err))
	}
	client := zproto.NewWebserverServiceClient(conn)

	ctx = withForwardedHeaders(ctx, r)

	endpointInfo, parseErr := parseEndpoint(c.Endpoint)
	if parseErr != nil {
		return fail(fmt.Errorf("failed to parse endpoint: %w", parseErr))
	}
	// The connection is dialed with host:port only, so the path component of
	// the endpoint is passed along as metadata instead.
	if endpointInfo.Path != "" {
		ctx = metadata.AppendToOutgoingContext(ctx, "x-forwarded-path", endpointInfo.Path)
	}

	textQuery, err := zquery.Parse(r.Query)
	if err != nil {
		return fail(fmt.Errorf("query parse failed: %w", err))
	}

	// Restrict the search to the repositories assigned to this connection.
	bitmap := roaring.NewBitmap()
	for _, id := range c.RepoIds {
		bitmap.Add(id)
	}
	var buf bytes.Buffer
	if _, writeErr := bitmap.WriteTo(&buf); writeErr != nil {
		return fail(fmt.Errorf("failed to serialize repo IDs: %w", writeErr))
	}

	children := []*zproto.Q{
		zquery.QToProto(textQuery),
		{
			Query: &zproto.Q_RepoIds{
				RepoIds: &zproto.RepoIds{
					Repos: buf.Bytes(),
				},
			},
		},
	}
	finalQuery := &zproto.Q{
		Query: &zproto.Q_And{
			And: &zproto.And{Children: children},
		},
	}

	searchReq := &zproto.SearchRequest{
		Query: finalQuery,
		Opts: &zproto.SearchOptions{
			NumContextLines:    int64(r.Options.NumContextLines),
			TotalMaxMatchCount: int64(r.Options.TotalMaxMatchCount),
		},
	}

	stream, err := client.StreamSearch(ctx, &zproto.StreamSearchRequest{Request: searchReq})
	if err != nil {
		// The error is returned as-is (not wrapped) so its original value is
		// what callers see; only the endpoint attribution is added.
		return fail(err)
	}

	return handleGrpcSearchStream(ctx, stream, c.Endpoint)
}
// handleGrpcSearchStream drains stream and folds every received chunk into a
// single ZoektResponse attributed to endpoint.
//
// File matches from all chunks are concatenated and the FileCount, MatchCount
// and NgramMatches statistics are summed. The loop ends when:
//   - ctx is done: ctx.Err() is returned. The context is checked after every
//     Recv, so it takes precedence over whatever Recv returned, including
//     io.EOF;
//   - Recv returns io.EOF: the aggregated result is returned;
//   - Recv returns any other error: that error is returned, wrapped;
//   - Recv returns a nil message with a nil error: an error is returned.
//
// If the summed NgramMatches is still zero at the end of the stream, it is
// derived from the matches themselves by counting the line fragments of every
// received file. This fallback is computed once, over the complete result, so
// it covers all chunks rather than only those received before the counter
// first became non-zero.
func handleGrpcSearchStream(ctx context.Context, stream SearchStream, endpoint string) ZoektResponse {
	var (
		matches []*zproto.FileMatch
		stats   zproto.Stats
	)

	for {
		resp, err := stream.Recv()

		if ctxErr := ctx.Err(); ctxErr != nil {
			return ZoektResponse{
				Error:    ctxErr,
				Endpoint: endpoint,
			}
		}

		if errors.Is(err, io.EOF) {
			// The stream reported no ngram matches at all: fall back to the
			// number of line fragments across the whole result.
			if stats.NgramMatches == 0 {
				var fragments int64
				for _, file := range matches {
					for _, lm := range file.LineMatches {
						fragments += int64(len(lm.LineFragments))
					}
				}
				stats.NgramMatches = fragments
			}

			return ZoektResponse{
				Endpoint: endpoint,
				Result: zoekt.SearchResult{
					Stats: zoekt.Stats{
						FileCount:    int(stats.FileCount),
						MatchCount:   int(stats.MatchCount),
						NgramMatches: int(stats.NgramMatches),
					},
					Files: convertGrpcFiles(matches),
				},
			}
		}
		if err != nil {
			return ZoektResponse{
				Error:    fmt.Errorf("error receiving from upstream stream: %w", err),
				Endpoint: endpoint,
			}
		}

		if resp == nil {
			return ZoektResponse{
				Error:    fmt.Errorf("received nil response from stream without EOF"),
				Endpoint: endpoint,
			}
		}

		// Messages that carry no response chunk contribute nothing.
		chunk := resp.GetResponseChunk()
		if chunk == nil {
			continue
		}

		matches = append(matches, chunk.Files...)
		if chunk.Stats != nil {
			stats.FileCount += chunk.Stats.FileCount
			stats.MatchCount += chunk.Stats.MatchCount
			stats.NgramMatches += chunk.Stats.NgramMatches
		}
	}
}
// getGrpcConn returns the gRPC client connection for endpoint, creating and
// caching it on first use.
//
// Connections are cached in searcher.GrpcConns under the endpoint's host:port,
// guarded by searcher.grpcMutex. An "https" endpoint gets TLS transport
// credentials (TLS 1.2 minimum, system root certificates when they can be
// loaded); any other endpoint gets insecure transport credentials. The TLS
// server name is searcher.TLSServerName when set, otherwise the host part of
// the endpoint.
//
// NOTE(review): the cache key does not include the protocol, so an http and an
// https endpoint with the same host:port share whichever connection was
// created first — confirm this is intended.
//
// It returns an error if endpoint cannot be parsed or the client cannot be
// created.
func (searcher *Searcher) getGrpcConn(endpoint string) (*grpc.ClientConn, error) {
	endpointInfo, parseErr := parseEndpoint(endpoint)
	if parseErr != nil {
		return nil, fmt.Errorf("failed to parse endpoint: %w", parseErr)
	}

	searcher.grpcMutex.Lock()
	defer searcher.grpcMutex.Unlock()

	if searcher.GrpcConns == nil {
		searcher.GrpcConns = make(map[string]*grpc.ClientConn)
	}
	if conn, ok := searcher.GrpcConns[endpointInfo.HostPort]; ok {
		return conn, nil
	}

	transportCreds := insecure.NewCredentials()
	if endpointInfo.Protocol == "https" {
		tlsConfig := &tls.Config{
			ServerName: searcher.TLSServerName,
			MinVersion: tls.VersionTLS12,
		}
		// If no server name is explicitly set, use the host from the endpoint.
		if tlsConfig.ServerName == "" {
			tlsConfig.ServerName = grpcTLSHost(endpointInfo.HostPort)
		}

		// Use system certificates; on failure keep RootCAs unset and carry on.
		systemPool, err := x509.SystemCertPool()
		if err != nil {
			slog.Warn("could not load system certificate pool", "error", err)
		} else if systemPool != nil {
			tlsConfig.RootCAs = systemPool
		}

		transportCreds = credentials.NewTLS(tlsConfig)
	}

	// Create connection only using host:port.
	newConn, err := grpc.NewClient(endpointInfo.HostPort, grpc.WithTransportCredentials(transportCreds))
	if err != nil {
		return nil, err
	}

	searcher.GrpcConns[endpointInfo.HostPort] = newConn
	return newConn, nil
}

// grpcTLSHost extracts the host from a "host[:port]" string for use as a TLS
// server name. Unlike splitting on ":", it handles bracketed IPv6 literals
// such as "[::1]:443" and values that carry no port.
func grpcTLSHost(hostPort string) string {
	host, _, err := net.SplitHostPort(hostPort)
	if err != nil {
		// No (valid) port component: the whole value is the host. Drop the
		// brackets a bare IPv6 literal may still be wrapped in.
		return strings.TrimSuffix(strings.TrimPrefix(hostPort, "["), "]")
	}
	return host
}
// convertGrpcFiles maps protobuf file matches onto zoekt.FileMatch values,
// preserving order. RepositoryID is filled in only when the Repository field
// is a base-10 number that fits in a uint32; otherwise it stays zero.
func convertGrpcFiles(grpcFiles []*zproto.FileMatch) []zoekt.FileMatch {
	converted := make([]zoekt.FileMatch, 0, len(grpcFiles))
	for _, src := range grpcFiles {
		fm := zoekt.FileMatch{
			FileName:    string(src.FileName),
			Repository:  src.Repository,
			Branches:    src.Branches,
			LineMatches: convertGrpcLineMatches(src.LineMatches),
			Checksum:    src.Checksum,
			Score:       src.Score,
			Language:    src.Language,
			Version:     src.Version,
		}
		// A bit size of 32 makes ParseUint itself reject out-of-range values.
		if id, err := strconv.ParseUint(src.Repository, 10, 32); err == nil {
			fm.RepositoryID = uint32(id)
		}
		converted = append(converted, fm)
	}
	return converted
}
// convertGrpcLineMatches maps protobuf line matches onto zoekt.LineMatch
// values, preserving order and converting the nested line fragments as well.
func convertGrpcLineMatches(grpcLines []*zproto.LineMatch) []zoekt.LineMatch {
	converted := make([]zoekt.LineMatch, 0, len(grpcLines))
	for _, src := range grpcLines {
		converted = append(converted, zoekt.LineMatch{
			Line:          src.Line,
			Before:        src.Before,
			After:         src.After,
			LineStart:     int(src.LineStart),
			LineEnd:       int(src.LineEnd),
			LineNumber:    int(src.LineNumber),
			FileName:      src.FileName,
			Score:         src.Score,
			DebugScore:    src.DebugScore,
			LineFragments: convertGrpcLineFragments(src.LineFragments),
		})
	}
	return converted
}
// convertGrpcLineFragments maps protobuf line-fragment matches onto
// zoekt.LineFragmentMatch values, preserving order.
func convertGrpcLineFragments(fragments []*zproto.LineFragmentMatch) []zoekt.LineFragmentMatch {
	converted := make([]zoekt.LineFragmentMatch, 0, len(fragments))
	for _, src := range fragments {
		converted = append(converted, zoekt.LineFragmentMatch{
			LineOffset:  int(src.LineOffset),
			Offset:      src.Offset,
			MatchLength: int(src.MatchLength),
		})
	}
	return converted
}
// Close closes every cached gRPC connection and resets the cache to an empty
// map. All connections are attempted even if some fail to close; the failures
// are combined into a single error. It returns nil when every connection
// closed cleanly.
func (searcher *Searcher) Close() error {
	searcher.grpcMutex.Lock()
	defer searcher.grpcMutex.Unlock()

	var closeErrs []error
	for endpoint, conn := range searcher.GrpcConns {
		err := conn.Close()
		if err == nil {
			continue
		}
		closeErrs = append(closeErrs, fmt.Errorf("error closing gRPC connection to %s: %w", endpoint, err))
	}
	searcher.GrpcConns = make(map[string]*grpc.ClientConn)

	// errors.Join yields nil when there is nothing to join.
	return errors.Join(closeErrs...)
}
// parseEndpoint splits an endpoint URL of the form
// "http(s)://host[:port][/path]" into its protocol, host:port and path.
//
// Only the "https://" and "http://" prefixes are recognized; anything else
// yields an "unknown protocol" error, with only Original set on the returned
// value. Path keeps its leading "/" and stays empty when the endpoint has no
// "/" after the host:port part.
func parseEndpoint(endpoint string) (EndpointInfo, error) {
	info := EndpointInfo{Original: endpoint}

	switch {
	case strings.HasPrefix(endpoint, "https://"):
		info.Protocol = "https"
	case strings.HasPrefix(endpoint, "http://"):
		info.Protocol = "http"
	default:
		return info, errors.New("unknown protocol")
	}

	// Everything after "<protocol>://" is host:port, optionally followed by a path.
	remainder := endpoint[len(info.Protocol)+len("://"):]
	hostPort, path, hasPath := strings.Cut(remainder, "/")
	info.HostPort = hostPort
	if hasPath {
		info.Path = "/" + path
	}

	return info, nil
}