// internal/search/search.go
package search
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"sort"
"strings"
"sync"
"time"
"google.golang.org/grpc"
)
const (
defaultSearchTimeout = "60s"
defaultMaxLineMatchWindow = uint32(5000)
)
var (
forwardedHeaders = []string{
"Authorization",
}
)
// Search builds a SearchRequest from the incoming HTTP request and runs it
// against all forward-to nodes, returning the combined result.
//
// Errors from request construction and from the multi-node search are
// wrapped with context for the caller.
func (searcher *Searcher) Search(r *http.Request) (*SearchResult, error) {
	req, err := NewSearchRequest(r)
	if err != nil {
		return nil, fmt.Errorf("error building search request: %w", err)
	}
	result, err := searcher.multiNodeSearch(r.Context(), req)
	if err == nil {
		return result, nil
	}
	return nil, fmt.Errorf("multi node search failed: %w", err)
}
// NewSearcher returns a Searcher that issues HTTP requests with the given
// client and starts with an empty map of gRPC connections keyed by endpoint.
func NewSearcher(client *http.Client) *Searcher {
	s := &Searcher{Client: client}
	s.GrpcConns = map[string]*grpc.ClientConn{}
	return s
}
// NewSearchRequest decodes and validates the JSON body of r into a
// SearchRequest. It applies defaults for the timeout string and the line
// match window, and copies allow-listed headers from the incoming request
// so they can be forwarded to downstream nodes.
//
// An error is returned when the body cannot be read or parsed, when no
// forward-to connections are specified, when a connection is missing an
// endpoint or repo IDs, when the query is empty, or when the timeout
// string cannot be parsed as a duration.
func NewSearchRequest(r *http.Request) (*SearchRequest, error) {
	body, err := io.ReadAll(r.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read request body: %w", err)
	}
	var req SearchRequest
	err = json.Unmarshal(body, &req) // nolint:musttag
	if err != nil {
		return nil, fmt.Errorf("failed to parse request body: %w", err)
	}
	if len(req.ForwardTo) == 0 {
		return nil, errors.New("no forward-to connections specified")
	}
	// Validate each downstream connection exactly once. (The previous
	// version ran this identical loop twice back to back.)
	for _, conn := range req.ForwardTo {
		if conn.Endpoint == "" {
			return nil, errors.New("forward-to endpoint is empty")
		}
		if len(conn.RepoIds) == 0 {
			return nil, errors.New("no repo IDs specified for forward-to connection")
		}
	}
	if req.Query == "" {
		return nil, errors.New("search query is empty")
	}
	if req.TimeoutString == "" {
		req.TimeoutString = defaultSearchTimeout
	}
	timeout, err := time.ParseDuration(req.TimeoutString)
	if err != nil {
		return nil, fmt.Errorf("failed to parse Timeout: %v with error %w", req.TimeoutString, err)
	}
	req.Timeout = timeout
	// MaxLineMatchWindow falls back to TotalMaxMatchCount when that is set,
	// otherwise to the package-wide default.
	if req.Options.MaxLineMatchWindow == 0 {
		if req.Options.TotalMaxMatchCount == 0 {
			req.Options.MaxLineMatchWindow = defaultMaxLineMatchWindow
		} else {
			req.Options.MaxLineMatchWindow = req.Options.TotalMaxMatchCount
		}
	}
	req.Headers = map[string]string{
		"Content-Type": "application/json",
		"Accept":       "application/json",
	}
	// Propagate allow-listed headers (e.g. Authorization) to downstream
	// requests; multiple values are joined with commas.
	for _, header := range forwardedHeaders {
		if values := r.Header.Values(header); len(values) > 0 {
			req.Headers[header] = strings.Join(values, ",")
		}
	}
	return &req, nil
}
// multiNodeSearch fans the request out to every forward-to connection in
// parallel, combines the responses, and applies the request's result limits.
//
// It returns an error when the overall timeout elapses with no file matches
// collected, or when every downstream search fails (the individual failures
// are joined into one error). Partial results with some failures are
// returned successfully.
func (searcher *Searcher) multiNodeSearch(ctx context.Context, s *SearchRequest) (*SearchResult, error) {
	var wg sync.WaitGroup
	// Buffered to len(ForwardTo) so each worker can deliver one response.
	responses := make(chan ZoektResponse, len(s.ForwardTo))
	timeoutCtx, cancel := context.WithTimeout(ctx, s.Timeout)
	defer cancel()
	for i := range s.ForwardTo {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			resp := searcher.DoSearch(timeoutCtx, s, &s.ForwardTo[i])
			// Give up on delivery if the search has been cancelled.
			select {
			case responses <- resp:
			case <-timeoutCtx.Done():
			}
		}(i)
	}
	// Close the channel once all workers are done so combineResults can
	// detect completion.
	go func() {
		wg.Wait()
		close(responses)
	}()
	result := combineResults(timeoutCtx, responses, s, cancel)
	if result.TimedOut && len(result.Result.Files) == 0 {
		// errors.New: the message has no format verbs, so fmt.Errorf adds nothing.
		return nil, errors.New("search timed out")
	}
	if len(result.Failures) == len(s.ForwardTo) {
		var errMsgs []error
		for _, failure := range result.Failures {
			errMsgs = append(errMsgs, fmt.Errorf("%s: %s", failure.Endpoint, failure.Error))
		}
		combinedErr := errors.Join(errMsgs...)
		return nil, fmt.Errorf("all searches failed: %w", combinedErr)
	}
	result.sort()
	// Trim to the requested maximum number of file matches (best-scored
	// files are kept because result is already sorted).
	if s.Options.MaxFileMatchResults > 0 && len(result.Result.Files) > int(s.Options.MaxFileMatchResults) {
		result.Result.Files = result.Result.Files[:s.Options.MaxFileMatchResults]
	}
	// Enforce the global and per-file line match limits, dropping any files
	// past the point where the global budget is exhausted.
	if s.Options.MaxLineMatchResults > 0 {
		remaining := int(s.Options.MaxLineMatchResults)
		maxPerFile := int(s.Options.MaxLineMatchResultsPerFile)
		for i := range result.Result.Files {
			lines := result.Result.Files[i].LineMatches
			if maxPerFile > 0 && len(lines) > maxPerFile {
				lines = lines[:maxPerFile]
			}
			if len(lines) > remaining {
				lines = lines[:remaining]
			}
			result.Result.Files[i].LineMatches = lines
			remaining -= len(lines)
			if remaining <= 0 {
				result.Result.Files = result.Result.Files[:i+1]
				break
			}
		}
	}
	return result, nil
}
// combineResults drains ch, accumulating successful responses and recording
// failures, until the channel is closed or ctx is cancelled.
//
// If the accumulated file or line counts reach the request's
// MaxFileMatchWindow / MaxLineMatchWindow, it calls cancel to stop the
// remaining downstream searches and returns what has been collected so far.
// On ctx cancellation the partial result is returned with TimedOut set.
func combineResults(ctx context.Context, ch <-chan ZoektResponse, s *SearchRequest, cancel context.CancelFunc) *SearchResult {
	combinedResult := &SearchResult{}
	var totalFileCount int
	var totalLineCount int
	for {
		select {
		case resp, ok := <-ch:
			if !ok {
				// Channel closed: every worker has reported.
				return combinedResult
			}
			if resp.Error != nil {
				// Record the failure but keep collecting from other nodes.
				combinedResult.Failures = append(combinedResult.Failures, SearchFailure{
					Error:    resp.Error.Error(),
					Endpoint: resp.Endpoint,
				})
				continue
			}
			result := resp.Result
			combinedResult.Result.FileCount += uint32(result.FileCount)       // #nosec G115
			combinedResult.Result.MatchCount += uint32(result.MatchCount)     // #nosec G115
			combinedResult.Result.NgramMatches += uint32(result.NgramMatches) // #nosec G115
			for _, f := range result.Files {
				lineCount := len(f.LineMatches)
				totalLineCount += lineCount
				totalFileCount++
				combinedResult.Result.Files = append(combinedResult.Result.Files, f)
				// Stop early once either match window is full; cancel tells
				// the remaining workers to abandon their searches.
				if (s.Options.MaxFileMatchWindow > 0 && totalFileCount >= int(s.Options.MaxFileMatchWindow)) ||
					(s.Options.MaxLineMatchWindow > 0 && totalLineCount >= int(s.Options.MaxLineMatchWindow)) {
					cancel()
					return combinedResult
				}
			}
		case <-ctx.Done():
			// Timeout or upstream cancellation: return what we have.
			combinedResult.TimedOut = true
			return combinedResult
		}
	}
}
// sort orders the combined file matches in place by descending score.
func (s *SearchResult) sort() {
	files := s.Result.Files
	sort.Slice(files, func(a, b int) bool {
		return files[b].Score < files[a].Score
	})
}