internal/task_request/task_request.go (172 lines of code) (raw):
package task_request //nolint:staticcheck
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strconv"
"time"
"gitlab.com/gitlab-org/gitlab-zoekt-indexer/internal/authentication"
"gitlab.com/gitlab-org/gitlab-zoekt-indexer/internal/disk_stats"
"gitlab.com/gitlab-org/gitlab-zoekt-indexer/internal/indexing_lock"
"gitlab.com/gitlab-org/gitlab-zoekt-indexer/internal/server"
"gitlab.com/gitlab-org/gitlab-zoekt-indexer/internal/task_request_response"
)
// Settings used by the task request (heartbeat) timer defined in this file.
const (
// apiSecretHeaderName is the HTTP request header in which SendRequest places
// the token produced by the authenticator.
apiSecretHeaderName = "Gitlab-Shell-Api-Request" //nolint:gosec //TODO: https://gitlab.com/gitlab-org/gitlab/-/issues/426557
// apiBasePath is joined onto the GitLab URL, followed by the node UUID and
// "heartbeat", to form the endpoint SendRequest posts to.
apiBasePath = "/api/v4/internal/search/zoekt"
jwtIssuer = "gitlab-shell" // Needs to be shell so that rails verifies the payload
// jwtTTL is handed to authentication.NewAuth together with jwtIssuer;
// presumably the lifetime of each generated token (confirm in that package).
jwtTTL = time.Minute
// tickerInterval is how often Start wakes up to check whether the next
// request is due.
tickerInterval = 100 * time.Millisecond
// defaultTaskRequestInterval is the initial delay between requests; Start
// replaces it when a processed response carries a non-zero Interval.
defaultTaskRequestInterval = 10 * time.Second
// requestTimeout bounds a single SendRequest call via context.WithTimeout.
requestTimeout = 30 * time.Second
)
// NewTaskRequestTimerParams groups the inputs consumed by NewTaskRequestTimer.
type NewTaskRequestTimerParams struct {
// IndexDir is the directory handed to disk_stats.DiskUsage when the payload is built.
IndexDir string
// NodeName is reported in the payload as "node.name".
NodeName string
// NodeUUID is used as a path segment of the heartbeat URL.
NodeUUID string
// Version is reported in the payload as "node.version".
Version string
// SelfURL is reported in the payload as "node.url".
SelfURL string
// SearchURL is reported in the payload as "node.search_url".
SearchURL string
// GitlabURL is the base URL onto which the heartbeat API path is joined.
GitlabURL string
// Secret is passed to authentication.NewAuth to build the request authenticator.
Secret []byte
// Concurrency is reported in the payload as "node.concurrency".
Concurrency int
// IndexingLock supplies the in-progress count reported as "node.task_count".
IndexingLock *indexing_lock.IndexingLock
}
// taskRequestTimer periodically posts a heartbeat request to GitLab (see Start
// and SendRequest) and hands the response to task_request_response.Process.
type taskRequestTimer struct {
// indexDir is the directory whose disk usage is reported in the payload.
indexDir string
// nodeName, version, selfURL and searchURL are reported verbatim in the payload.
nodeName string
// nodeUUID is used as a path segment of the heartbeat URL.
nodeUUID string
version string
// gitLabURL is the base URL onto which the heartbeat API path is joined.
gitLabURL string
selfURL string
searchURL string
// authenticator returns the token sent in the apiSecretHeaderName header.
authenticator func() (string, error)
// concurrency is reported in the payload as "node.concurrency".
concurrency int
// indexingLock supplies the in-progress count reported as "node.task_count".
indexingLock *indexing_lock.IndexingLock
}
// NewTaskRequestTimer copies the settings in p into a new taskRequestTimer and
// wires its authenticator to the GenerateJWT method of an authentication.NewAuth
// instance built from jwtIssuer, jwtTTL and p.Secret.
//
// The returned error is always nil; it is part of the signature only.
func NewTaskRequestTimer(p *NewTaskRequestTimerParams) (*taskRequestTimer, error) {
	auth := authentication.NewAuth(jwtIssuer, jwtTTL, p.Secret)

	return &taskRequestTimer{
		authenticator: auth.GenerateJWT,
		gitLabURL:     p.GitlabURL,
		selfURL:       p.SelfURL,
		searchURL:     p.SearchURL,
		nodeUUID:      p.NodeUUID,
		nodeName:      p.NodeName,
		version:       p.Version,
		indexDir:      p.IndexDir,
		concurrency:   p.Concurrency,
		indexingLock:  p.IndexingLock,
	}, nil
}
// Start runs the request loop until ctx is cancelled, then returns ctx.Err().
//
// The loop wakes up every tickerInterval and sends a request once at least the
// current pull interval has elapsed since the previous attempt. The interval
// starts at defaultTaskRequestInterval (measured from the moment Start is
// called) and is replaced whenever a processed response carries a non-zero
// Interval. Failed requests are logged and retried after the same interval.
//
// Requests and response processing share a child context of ctx. When a
// response sets StopIndexing, that child context is cancelled and a fresh one
// is created for subsequent work.
func (h *taskRequestTimer) Start(ctx context.Context, s *server.IndexServer) error {
	ticker := time.NewTicker(tickerInterval)
	defer ticker.Stop()

	lastSentRequestAt := time.Now()
	pullFrequency := defaultTaskRequestInterval
	slog.Info("starting task request timer", "interval", pullFrequency.Seconds())

	taskCtx, cancelTasks := context.WithCancel(ctx)
	// The closure reads cancelTasks at return time, so it cancels whichever
	// child context is current when Start exits.
	defer func() { cancelTasks() }()

	for {
		// Block until either the parent context ends or the ticker fires.
		select {
		case <-ctx.Done():
			return ctx.Err() //nolint:govet // It is a false positive
		case <-ticker.C:
		}

		elapsed := time.Since(lastSentRequestAt)
		if elapsed < pullFrequency {
			continue
		}

		// SendRequest blocks until the request completes or fails.
		body, err := h.SendRequest(taskCtx)
		if err == nil {
			result := task_request_response.Process(taskCtx, body, s)
			if result.Interval != 0 {
				pullFrequency = result.Interval
			}
			if result.StopIndexing {
				// Stop all currently running indexing, then build a new
				// context for all future indexing.
				cancelTasks()
				taskCtx, cancelTasks = context.WithCancel(ctx) //nolint:govet // It is a false positive
			}
		} else {
			slog.Error("error while sending task request", "err", err)
		}

		lastSentRequestAt = time.Now()
		slog.Info("ticker TaskRequest information", "delta", elapsed.Seconds(), "lastSentRequestAt", lastSentRequestAt, "interval", pullFrequency.Seconds())
	}
}
// SendRequest posts this node's heartbeat payload to GitLab and returns the raw
// response body.
//
// The endpoint is gitLabURL joined with apiBasePath, the node UUID and
// "heartbeat". The request carries the token from h.authenticator in the
// apiSecretHeaderName header and is bounded by requestTimeout in addition to
// any cancellation of ctx.
//
// An error is returned when the URL, payload, request or token cannot be
// built, when the HTTP exchange or the body read fails, or when the response
// status is not 200 OK. For a read failure or a non-200 status the bytes read
// so far are returned alongside the error; otherwise an empty slice is.
func (h *taskRequestTimer) SendRequest(ctx context.Context) ([]byte, error) {
	response := []byte{}

	fullURL, err := url.JoinPath(h.gitLabURL, apiBasePath, h.nodeUUID, "heartbeat")
	if err != nil {
		return response, err
	}

	payload, err := h.buildRequestPayload()
	if err != nil {
		return response, err
	}

	requestCtx, requestCancel := context.WithTimeout(ctx, requestTimeout)
	defer requestCancel()

	req, err := http.NewRequestWithContext(requestCtx, http.MethodPost, fullURL, bytes.NewReader(payload))
	if err != nil {
		return response, err
	}

	tokenString, authErr := h.authenticator()
	if authErr != nil {
		return response, authErr
	}

	req.Header = http.Header{
		"Accept":            {"application/json"},
		apiSecretHeaderName: {tokenString},
	}
	if payload != nil {
		req.Header.Set("Content-Type", "application/json")
	}

	client := http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return response, err
	}
	// The caller of Do owns the body: closing it releases the underlying
	// connection (or lets the transport reuse it) instead of leaking one per
	// heartbeat.
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return body, err
	}

	if resp.StatusCode != http.StatusOK {
		return body, fmt.Errorf("TaskRequest responded with [%d]", resp.StatusCode)
	}

	slog.Debug("TaskRequest", "status", resp.StatusCode, "body", body)
	slog.Info("TaskRequest", "status", resp.StatusCode)

	return body, nil
}
// buildRequestPayload returns the JSON body of the heartbeat request.
//
// The body is a flat JSON object whose values are all strings: the node's URLs,
// name and version, the in-progress task count from the indexing lock, the
// configured concurrency, and the disk figures from disk_stats.DiskUsage for
// indexDir. When the reported Indexed value is -1 the object carries an "error"
// entry instead of "disk.indexed".
//
// On failure of DiskUsage or of JSON encoding it returns an empty slice and the
// error.
func (h *taskRequestTimer) buildRequestPayload() ([]byte, error) {
	usage, err := disk_stats.DiskUsage(h.indexDir)
	if err != nil {
		return []byte{}, err
	}

	fields := map[string]any{
		"node.url":         h.selfURL,
		"node.search_url":  h.searchURL,
		"node.name":        h.nodeName,
		"node.version":     h.version,
		"node.task_count":  strconv.Itoa(h.indexingLock.InProgressCount()),
		"node.concurrency": strconv.Itoa(h.concurrency),
		"disk.all":         strconv.FormatUint(usage.All, 10),
		"disk.free":        strconv.FormatUint(usage.Free, 10),
		"disk.used":        strconv.FormatUint(usage.Used, 10),
	}

	// -1 is the value at which the indexed size is reported as an error
	// rather than as a number.
	if usage.Indexed != -1 {
		fields["disk.indexed"] = strconv.FormatInt(usage.Indexed, 10)
	} else {
		fields["error"] = "Error in calculating disk.indexed"
	}

	encoded, err := json.Marshal(fields)
	if err != nil {
		return []byte{}, err
	}

	return encoded, nil
}