internal/task_request/task_request.go (172 lines of code) (raw):

package task_request //nolint:staticcheck import ( "bytes" "context" "encoding/json" "fmt" "io" "log/slog" "net/http" "net/url" "strconv" "time" "gitlab.com/gitlab-org/gitlab-zoekt-indexer/internal/authentication" "gitlab.com/gitlab-org/gitlab-zoekt-indexer/internal/disk_stats" "gitlab.com/gitlab-org/gitlab-zoekt-indexer/internal/indexing_lock" "gitlab.com/gitlab-org/gitlab-zoekt-indexer/internal/server" "gitlab.com/gitlab-org/gitlab-zoekt-indexer/internal/task_request_response" ) const ( apiSecretHeaderName = "Gitlab-Shell-Api-Request" //nolint:gosec //TODO: https://gitlab.com/gitlab-org/gitlab/-/issues/426557 apiBasePath = "/api/v4/internal/search/zoekt" jwtIssuer = "gitlab-shell" // Needs to be shell so that rails verifies the payload jwtTTL = time.Minute tickerInterval = 100 * time.Millisecond defaultTaskRequestInterval = 10 * time.Second requestTimeout = 30 * time.Second ) type NewTaskRequestTimerParams struct { IndexDir string NodeName string NodeUUID string Version string SelfURL string SearchURL string GitlabURL string Secret []byte Concurrency int IndexingLock *indexing_lock.IndexingLock } type taskRequestTimer struct { indexDir string nodeName string nodeUUID string version string gitLabURL string selfURL string searchURL string authenticator func() (string, error) concurrency int indexingLock *indexing_lock.IndexingLock } func NewTaskRequestTimer(p *NewTaskRequestTimerParams) (*taskRequestTimer, error) { authenticator := authentication.NewAuth(jwtIssuer, jwtTTL, p.Secret).GenerateJWT timer := taskRequestTimer{ indexDir: p.IndexDir, nodeName: p.NodeName, nodeUUID: p.NodeUUID, version: p.Version, gitLabURL: p.GitlabURL, selfURL: p.SelfURL, searchURL: p.SearchURL, authenticator: authenticator, concurrency: p.Concurrency, indexingLock: p.IndexingLock, } return &timer, nil } func (h *taskRequestTimer) Start(ctx context.Context, s *server.IndexServer) error { ticker := time.NewTicker(tickerInterval) defer ticker.Stop() lastSentRequestAt := time.Now() pullFrequency := defaultTaskRequestInterval slog.Info("starting task request timer", "interval", pullFrequency.Seconds()) localCtx, localCancel := context.WithCancel(ctx) defer func() { localCancel() }() // use the current value of 'localCancel' for { select { case <-ctx.Done(): return ctx.Err() //nolint:govet // It is a false positive case <-ticker.C: delta := time.Since(lastSentRequestAt) if delta >= pullFrequency { body, err := h.SendRequest(localCtx) if err != nil { // blocking slog.Error("error while sending task request", "err", err) } else { res := task_request_response.Process(localCtx, body, s) if res.Interval != 0 { pullFrequency = res.Interval } if res.StopIndexing { localCancel() // stop all currently running indexing // construct new context for all future indexing localCtx, localCancel = context.WithCancel(ctx) //nolint:govet // It is a false positive } } lastSentRequestAt = time.Now() slog.Info("ticker TaskRequest information", "delta", delta.Seconds(), "lastSentRequestAt", lastSentRequestAt, "interval", pullFrequency.Seconds()) } } } } func (h *taskRequestTimer) SendRequest(ctx context.Context) ([]byte, error) { fullURL, err := url.JoinPath(h.gitLabURL, apiBasePath, h.nodeUUID, "heartbeat") response := []byte{} if err != nil { return response, err } payload, err := h.buildRequestPayload() if err != nil { return response, err } requestCtx, requestCancel := context.WithTimeout(ctx, requestTimeout) defer requestCancel() req, err := http.NewRequestWithContext(requestCtx, http.MethodPost, fullURL, bytes.NewReader(payload)) if err != nil { return response, err } tokenString, authErr := h.authenticator() if authErr != nil { return response, authErr } req.Header = http.Header{ "Accept": {"application/json"}, apiSecretHeaderName: {tokenString}, } if payload != nil { req.Header.Set("Content-Type", "application/json") } client := http.Client{} resp, err := client.Do(req) //nolint:bodyclose if err != nil { return response, err } body, err := io.ReadAll(resp.Body) if err != nil { return body, err } if resp.StatusCode != http.StatusOK { return body, fmt.Errorf("TaskRequest responded with [%d]", resp.StatusCode) } slog.Debug("TaskRequest", "status", resp.StatusCode, "body", body) slog.Info("TaskRequest", "status", resp.StatusCode) return body, nil } func (h *taskRequestTimer) buildRequestPayload() ([]byte, error) { d, err := disk_stats.DiskUsage(h.indexDir) if err != nil { return []byte{}, err } inProgressCount := strconv.Itoa(h.indexingLock.InProgressCount()) params := map[string]interface{}{ "node.url": h.selfURL, "node.search_url": h.searchURL, "node.name": h.nodeName, "node.version": h.version, "node.task_count": inProgressCount, "node.concurrency": strconv.Itoa(h.concurrency), "disk.all": strconv.FormatUint(d.All, 10), "disk.free": strconv.FormatUint(d.Free, 10), "disk.used": strconv.FormatUint(d.Used, 10), } if d.Indexed == -1 { params["error"] = "Error in calculating disk.indexed" } else { params["disk.indexed"] = strconv.FormatInt(d.Indexed, 10) } payload, err := json.Marshal(params) if err != nil { return []byte{}, err } return payload, nil }