internal/search/search.go (198 lines of code) (raw):

// Package search fans a single search request out to multiple Zoekt
// nodes, combines their responses, and applies result-count limits.
package search

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"sort"
	"strings"
	"sync"
	"time"

	"google.golang.org/grpc"
)

const (
	// defaultSearchTimeout bounds a search when the request body does not
	// specify its own timeout.
	defaultSearchTimeout = "60s"
	// defaultMaxLineMatchWindow caps how many line matches are collected
	// across all nodes when the request sets neither MaxLineMatchWindow
	// nor TotalMaxMatchCount.
	defaultMaxLineMatchWindow = uint32(5000)
)

// forwardedHeaders lists the incoming HTTP headers that are copied onto
// the requests forwarded to each node.
var forwardedHeaders = []string{
	"Authorization",
}

// Search parses the incoming HTTP request into a SearchRequest and runs
// it against every configured forward-to node, returning the combined,
// sorted, and truncated result.
func (searcher *Searcher) Search(r *http.Request) (*SearchResult, error) {
	req, err := NewSearchRequest(r)
	if err != nil {
		return nil, fmt.Errorf("error building search request: %w", err)
	}

	result, err := searcher.multiNodeSearch(r.Context(), req)
	if err != nil {
		return nil, fmt.Errorf("multi node search failed: %w", err)
	}

	return result, nil
}

// NewSearcher returns a Searcher that issues HTTP requests with client
// and caches gRPC connections per endpoint in GrpcConns.
func NewSearcher(client *http.Client) *Searcher {
	return &Searcher{
		Client:    client,
		GrpcConns: make(map[string]*grpc.ClientConn),
	}
}

// NewSearchRequest decodes and validates the JSON body of r.
//
// It requires a non-empty query and at least one forward-to connection
// (each with an endpoint and repo IDs), applies the default timeout and
// line-match window when unset, and captures the forwarded headers from
// the incoming request. The request body is consumed.
func NewSearchRequest(r *http.Request) (*SearchRequest, error) {
	body, err := io.ReadAll(r.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read request body: %w", err)
	}

	var req SearchRequest
	err = json.Unmarshal(body, &req) // nolint:musttag
	if err != nil {
		return nil, fmt.Errorf("failed to parse request body: %w", err)
	}

	if len(req.ForwardTo) == 0 {
		return nil, errors.New("no forward-to connections specified")
	}
	// NOTE: the original code ran this validation loop twice verbatim;
	// one pass is sufficient and behaviorally identical.
	for _, conn := range req.ForwardTo {
		if conn.Endpoint == "" {
			return nil, errors.New("forward-to endpoint is empty")
		}
		if len(conn.RepoIds) == 0 {
			return nil, errors.New("no repo IDs specified for forward-to connection")
		}
	}

	if req.Query == "" {
		return nil, errors.New("search query is empty")
	}

	if req.TimeoutString == "" {
		req.TimeoutString = defaultSearchTimeout
	}
	timeout, err := time.ParseDuration(req.TimeoutString)
	if err != nil {
		return nil, fmt.Errorf("failed to parse Timeout: %v with error %w", req.TimeoutString, err)
	}
	req.Timeout = timeout

	// Fall back to TotalMaxMatchCount, then to the package default, so a
	// window is always in effect for combineResults.
	if req.Options.MaxLineMatchWindow == 0 {
		if req.Options.TotalMaxMatchCount == 0 {
			req.Options.MaxLineMatchWindow = defaultMaxLineMatchWindow
		} else {
			req.Options.MaxLineMatchWindow = req.Options.TotalMaxMatchCount
		}
	}

	req.Headers = map[string]string{
		"Content-Type": "application/json",
		"Accept":       "application/json",
	}
	for _, header := range forwardedHeaders {
		if values := r.Header.Values(header); len(values) > 0 {
			req.Headers[header] = strings.Join(values, ",")
		}
	}

	return &req, nil
}

// multiNodeSearch queries every forward-to node concurrently, merges
// their responses, and applies sorting plus the MaxFileMatchResults /
// MaxLineMatchResults truncation options.
//
// It fails outright only when the whole search times out with no files
// collected, or when every node reports a failure.
func (searcher *Searcher) multiNodeSearch(ctx context.Context, s *SearchRequest) (*SearchResult, error) {
	var wg sync.WaitGroup
	// Buffered so a worker can deliver its response even after the
	// collector has returned early.
	responses := make(chan ZoektResponse, len(s.ForwardTo))

	timeoutCtx, cancel := context.WithTimeout(ctx, s.Timeout)
	defer cancel()

	for i := range s.ForwardTo {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			resp := searcher.DoSearch(timeoutCtx, s, &s.ForwardTo[i])
			// Give up on delivery if the search has been cancelled.
			select {
			case responses <- resp:
			case <-timeoutCtx.Done():
			}
		}(i)
	}

	// Close the channel once all workers have finished so the collector
	// loop can terminate.
	go func() {
		wg.Wait()
		close(responses)
	}()

	result := combineResults(timeoutCtx, responses, s, cancel)

	if result.TimedOut && len(result.Result.Files) == 0 {
		return nil, errors.New("search timed out")
	}

	// Only treat the search as failed when no node succeeded.
	if len(result.Failures) == len(s.ForwardTo) {
		var errMsgs []error
		for _, failure := range result.Failures {
			errMsgs = append(errMsgs, fmt.Errorf("%s: %s", failure.Endpoint, failure.Error))
		}
		combinedErr := errors.Join(errMsgs...)
		return nil, fmt.Errorf("all searches failed: %w", combinedErr)
	}

	result.sort()

	if s.Options.MaxFileMatchResults > 0 && len(result.Result.Files) > int(s.Options.MaxFileMatchResults) {
		result.Result.Files = result.Result.Files[:s.Options.MaxFileMatchResults]
	}

	// Trim line matches per file and overall, dropping any files past
	// the point where the overall budget is exhausted.
	if s.Options.MaxLineMatchResults > 0 {
		remaining := int(s.Options.MaxLineMatchResults)
		maxPerFile := int(s.Options.MaxLineMatchResultsPerFile)
		for i := range result.Result.Files {
			lines := result.Result.Files[i].LineMatches
			if maxPerFile > 0 && len(lines) > maxPerFile {
				lines = lines[:maxPerFile]
			}
			if len(lines) > remaining {
				lines = lines[:remaining]
			}
			result.Result.Files[i].LineMatches = lines
			remaining -= len(lines)
			if remaining <= 0 {
				result.Result.Files = result.Result.Files[:i+1]
				break
			}
		}
	}

	return result, nil
}

// combineResults drains ch, accumulating per-node results and failures
// into a single SearchResult. It stops early — cancelling the remaining
// workers via cancel — once the MaxFileMatchWindow or MaxLineMatchWindow
// collection window is reached, and marks the result timed out when ctx
// expires before the channel closes.
func combineResults(ctx context.Context, ch <-chan ZoektResponse, s *SearchRequest, cancel context.CancelFunc) *SearchResult {
	combinedResult := &SearchResult{}
	var totalFileCount int
	var totalLineCount int

	for {
		select {
		case resp, ok := <-ch:
			if !ok {
				// Channel closed: every worker has reported.
				return combinedResult
			}
			if resp.Error != nil {
				combinedResult.Failures = append(combinedResult.Failures, SearchFailure{
					Error:    resp.Error.Error(),
					Endpoint: resp.Endpoint,
				})
				continue
			}

			result := resp.Result
			combinedResult.Result.FileCount += uint32(result.FileCount)       // #nosec G115
			combinedResult.Result.MatchCount += uint32(result.MatchCount)     // #nosec G115
			combinedResult.Result.NgramMatches += uint32(result.NgramMatches) // #nosec G115

			for _, f := range result.Files {
				lineCount := len(f.LineMatches)
				totalLineCount += lineCount
				totalFileCount++
				combinedResult.Result.Files = append(combinedResult.Result.Files, f)

				// Stop collecting once either window is full; cancel so
				// in-flight node searches are abandoned.
				if (s.Options.MaxFileMatchWindow > 0 && totalFileCount >= int(s.Options.MaxFileMatchWindow)) ||
					(s.Options.MaxLineMatchWindow > 0 && totalLineCount >= int(s.Options.MaxLineMatchWindow)) {
					cancel()
					return combinedResult
				}
			}
		case <-ctx.Done():
			combinedResult.TimedOut = true
			return combinedResult
		}
	}
}

// sort orders the combined files by descending score so the best
// matches come first.
func (s *SearchResult) sort() {
	sort.Slice(s.Result.Files, func(i, j int) bool {
		return s.Result.Files[i].Score > s.Result.Files[j].Score
	})
}