internal/pkg/api/handleAck.go
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License;
// you may not use this file except in compliance with the Elastic License.

package api

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/miolini/datacounter"
"github.com/rs/zerolog"
"go.elastic.co/apm/module/apmhttp/v2"
"go.elastic.co/apm/v2"
"github.com/elastic/fleet-server/v7/internal/pkg/bulk"
"github.com/elastic/fleet-server/v7/internal/pkg/cache"
"github.com/elastic/fleet-server/v7/internal/pkg/config"
"github.com/elastic/fleet-server/v7/internal/pkg/dl"
"github.com/elastic/fleet-server/v7/internal/pkg/es"
"github.com/elastic/fleet-server/v7/internal/pkg/logger"
"github.com/elastic/fleet-server/v7/internal/pkg/model"
"github.com/elastic/fleet-server/v7/internal/pkg/policy"
"github.com/elastic/fleet-server/v7/internal/pkg/smap"
)

const (
TypeUnenroll = "UNENROLL"
TypeUpgrade = "UPGRADE"
)

var (
ErrUpdatingInactiveAgent = errors.New("updating inactive agent")
)
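
// HTTPError wraps an HTTP status code so it can be returned as an error;
// handleAckEvents uses it to surface the largest per-item status to the caller.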
type HTTPError struct {
Status int
}
func (e *HTTPError) Error() string {
return fmt.Sprintf("%d: %s", e.Status, http.StatusText(e.Status))
}
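
// NewAckResponse returns an AckResponse for the "acks" action with a preallocated item slice of the given size.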
func NewAckResponse(size int) AckResponse {
return AckResponse{
Action: "acks",
Errors: false,
Items: make([]AckResponseItem, size),
}
}
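
// setMessage records the status and message for the item at pos and marks the response as errored for any non-200 status.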
func (a *AckResponse) setMessage(pos int, status int, message string) {
if status != http.StatusOK {
a.Errors = true
}
a.Items[pos].Status = status
a.Items[pos].Message = &message
}
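
// SetResult sets the item at pos to the given status, using the standard status text as the message.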
func (a *AckResponse) SetResult(pos int, status int) {
a.setMessage(pos, status, http.StatusText(status))
}
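
// SetError records an Elasticsearch error's status and reason for the item at pos; any other error is reported as a 500.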
func (a *AckResponse) SetError(pos int, err error) {
var esErr *es.ErrElastic
if errors.As(err, &esErr) {
a.setMessage(pos, esErr.Status, esErr.Reason)
} else {
a.SetResult(pos, http.StatusInternalServerError)
}
}
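
// AckT serves the ack endpoint: it authenticates agents, validates ack requests, and records the results of acked actions.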
type AckT struct {
cfg *config.Server
bulk bulk.Bulk
cache cache.Cache
}
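
// NewAckT creates an AckT with the given server config, bulker, and cache.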
func NewAckT(cfg *config.Server, bulker bulk.Bulk, cache cache.Cache) *AckT {
return &AckT{
cfg: cfg,
bulk: bulker,
cache: cache,
}
}
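
// handleAcks authenticates the agent identified by id, attaches its API key ID to the logger, and delegates to processRequest.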
func (ack *AckT) handleAcks(zlog zerolog.Logger, w http.ResponseWriter, r *http.Request, id string) error {
agent, err := authAgent(r, &id, ack.bulk, ack.cache)
if err != nil {
return err
}
zlog = zlog.With().Str(LogAccessAPIKeyID, agent.AccessAPIKeyID).Logger()
ctx := zlog.WithContext(r.Context())
r = r.WithContext(ctx)
return ack.processRequest(zlog, w, r, agent)
}
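
// processRequest validates the ack request, processes its events, and writes the AckResponse body;
// the body is written even when an error HTTP status was set, while non-HTTP errors are returned to the caller.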
func (ack *AckT) processRequest(zlog zerolog.Logger, w http.ResponseWriter, r *http.Request, agent *model.Agent) error {
req, err := ack.validateRequest(zlog, w, r)
if err != nil {
return err
}
zlog = zlog.With().Int("nEvents", len(req.Events)).Logger()
resp, err := ack.handleAckEvents(r.Context(), zlog, agent, req.Events)
span, _ := apm.StartSpan(r.Context(), "response", "write")
defer span.End()
if err != nil {
var herr *HTTPError
if errors.As(err, &herr) {
w.WriteHeader(herr.Status)
} else {
// Non-HTTP errors are handled at a higher level
return err
}
}
// Always write the response body, even if an error HTTP status code was set
data, err := json.Marshal(&resp)
if err != nil {
return fmt.Errorf("handleAcks marshal response: %w", err)
}
var nWritten int
if nWritten, err = w.Write(data); err != nil {
return err
}
cntAcks.bodyOut.Add(uint64(nWritten)) //nolint:gosec // disable G115
return nil
}
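
// validateRequest decodes the ack request body, enforcing the configured max body size and counting the bytes read.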
func (ack *AckT) validateRequest(zlog zerolog.Logger, w http.ResponseWriter, r *http.Request) (*AckRequest, error) {
span, _ := apm.StartSpan(r.Context(), "validateRequest", "validate")
defer span.End()
body := r.Body
// Limit the size of the body to prevent a malicious agent from exhausting RAM on the server
if ack.cfg.Limits.AckLimit.MaxBody > 0 {
body = http.MaxBytesReader(w, body, ack.cfg.Limits.AckLimit.MaxBody)
}
readCounter := datacounter.NewReaderCounter(body)
var req AckRequest
dec := json.NewDecoder(readCounter)
if err := dec.Decode(&req); err != nil {
return nil, &BadRequestErr{msg: "unable to decode ack request", nextErr: err}
}
cntAcks.bodyIn.Add(readCounter.Count())
zlog.Trace().Msg("Ack request")
return &req, nil
}
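
// eventToActionResult converts an acked event into an action result document; the fields captured depend on the action type.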
func eventToActionResult(agentID, aType string, namespaces []string, ev AckRequest_Events_Item) (acr model.ActionResult) {
switch aType {
case string(REQUESTDIAGNOSTICS):
event, _ := ev.AsDiagnosticsEvent()
p, _ := json.Marshal(event.Data)
return model.ActionResult{
ActionID: event.ActionId,
AgentID: agentID,
Namespaces: namespaces,
Data: p,
Error: fromPtr(event.Error),
Timestamp: event.Timestamp.Format(time.RFC3339Nano),
}
case string(INPUTACTION):
event, _ := ev.AsInputEvent()
return model.ActionResult{
ActionID: event.ActionId,
AgentID: agentID,
Namespaces: namespaces,
ActionInputType: event.ActionInputType,
StartedAt: event.StartedAt.Format(time.RFC3339Nano),
CompletedAt: event.CompletedAt.Format(time.RFC3339Nano),
ActionData: event.ActionData,
ActionResponse: event.ActionResponse,
Error: fromPtr(event.Error),
Timestamp: event.Timestamp.Format(time.RFC3339Nano),
}
default: // UPGRADE action acks are also handled by handleUpgrade (deprecated func)
event, _ := ev.AsGenericEvent()
return model.ActionResult{
ActionID: event.ActionId,
Namespaces: namespaces,
AgentID: agentID,
Error: fromPtr(event.Error),
Timestamp: event.Timestamp.Format(time.RFC3339Nano),
}
}
}
// handleAckEvents can return:
// 1. AckResponse and nil error, when the whole request is successful
// 2. AckResponse and non-nil error, when the request items had errors
func (ack *AckT) handleAckEvents(ctx context.Context, zlog zerolog.Logger, agent *model.Agent, events []AckRequest_Events_Item) (AckResponse, error) {
span, ctx := apm.StartSpan(ctx, "handleAckEvents", "process")
defer span.End()
var policyAcks []string
var policyIdxs []int
var unenrollIdxs []int
res := NewAckResponse(len(events))
// httpErr collects the largest error HTTP status code across all acked events
httpErr := HTTPError{http.StatusOK}
setResult := func(pos, status int) {
if status > httpErr.Status {
httpErr.Status = status
}
res.SetResult(pos, status)
}
setError := func(pos int, err error) {
var esErr *es.ErrElastic
if errors.As(err, &esErr) {
setResult(pos, esErr.Status)
} else {
setResult(pos, http.StatusInternalServerError)
}
res.SetError(pos, err)
e := apm.CaptureError(ctx, err)
e.Send()
}
for n, ev := range events {
event, _ := ev.AsGenericEvent()
span, ctx := apm.StartSpan(ctx, "ackEvent", "process")
span.Context.SetLabel("agent_id", agent.Agent.ID)
span.Context.SetLabel("action_id", event.ActionId)
log := zlog.With().
Str(logger.ActionID, event.ActionId).
Str(logger.AgentID, event.AgentId).
Time("timestamp", event.Timestamp).
Int("n", n).Logger()
log.Info().Msg("ack event")
// Check agent id mismatch
if event.AgentId != "" && event.AgentId != agent.Id {
log.Error().Msg("agent id mismatch")
setResult(n, http.StatusBadRequest)
span.End()
continue
}
// Check if this is the policy change ack
// The policy change acks are handled after actions
if strings.HasPrefix(event.ActionId, "policy:") {
if event.Error == nil {
// only added if no error on action
policyAcks = append(policyAcks, event.ActionId)
policyIdxs = append(policyIdxs, n)
}
// Set OK status; this may be overwritten later if errors occur when the policy change events are acked
setResult(n, http.StatusOK)
span.End()
continue
}
// Process non-policy change actions
// Find matching action by action ID
vSpan, vCtx := apm.StartSpan(ctx, "ackAction", "validate")
action, ok := ack.cache.GetAction(event.ActionId)
if !ok {
// Find action by ID
actions, err := dl.FindAction(vCtx, ack.bulk, event.ActionId)
if err != nil {
log.Error().Err(err).Msg("find action")
setError(n, err)
vSpan.End()
span.End()
continue
}
// Set 404 if action is not found. The agent can retry it later.
if len(actions) == 0 {
log.Error().Msg("no matching action")
setResult(n, http.StatusNotFound)
vSpan.End()
span.End()
continue
}
action = actions[0]
ack.cache.SetAction(action)
}
vSpan.End()
if err := ack.handleActionResult(ctx, zlog, agent, action, ev); err != nil {
setError(n, err)
} else {
setResult(n, http.StatusOK)
}
if event.Error == nil && action.Type == TypeUnenroll {
unenrollIdxs = append(unenrollIdxs, n)
}
span.End()
}
// Process policy acks
if len(policyAcks) > 0 {
if err := ack.handlePolicyChange(ctx, zlog, agent, policyAcks...); err != nil {
for _, idx := range policyIdxs {
setError(idx, err)
}
}
}
// Process unenroll acks
if len(unenrollIdxs) > 0 {
if err := ack.handleUnenroll(ctx, zlog, agent); err != nil {
zlog.Error().Err(err).Msg("handle unenroll event")
// Set errors for each unenroll event
for _, idx := range unenrollIdxs {
setError(idx, err)
}
}
}
// Return both the data and error code
if httpErr.Status > http.StatusOK {
return res, &httpErr
}
return res, nil
}
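
// handleActionResult persists the action result document for an acked action and, for UPGRADE actions, updates the agent's upgrade state.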
func (ack *AckT) handleActionResult(ctx context.Context, zlog zerolog.Logger, agent *model.Agent, action model.Action, ev AckRequest_Events_Item) error {
// Build span links for actions
var links []apm.SpanLink
if ack.bulk.HasTracer() && action.Traceparent != "" {
traceCtx, err := apmhttp.ParseTraceparentHeader(action.Traceparent)
if err != nil {
zlog.Trace().Err(err).Msgf("Error parsing traceparent: %s %s", action.Traceparent, err)
} else {
links = []apm.SpanLink{
{
Trace: traceCtx.Trace,
Span: traceCtx.Span,
},
}
}
}
span, ctx := apm.StartSpanOptions(ctx, fmt.Sprintf("Process action result %s", action.Type), "process", apm.SpanOptions{Links: links})
span.Context.SetLabel("action_id", action.Id)
span.Context.SetLabel("agent_id", agent.Agent.ID)
defer span.End()
// Convert ack event to action result document
acr := eventToActionResult(agent.Id, action.Type, action.Namespaces, ev)
// Save action result document
if err := dl.CreateActionResult(ctx, ack.bulk, acr); err != nil {
zlog.Error().Err(err).Str(logger.AgentID, agent.Agent.ID).Str(logger.ActionID, action.Id).Msg("create action result")
return err
}
if action.Type == TypeUpgrade {
event, _ := ev.AsUpgradeEvent()
if err := ack.handleUpgrade(ctx, zlog, agent, event); err != nil {
zlog.Error().Err(err).Str(logger.AgentID, agent.Agent.ID).Str(logger.ActionID, action.Id).Msg("handle upgrade event")
return err
}
}
return nil
}
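
// handlePolicyChange acks one or more policy change actions: it picks the highest acked revision for the agent's current policy,
// updates the Elasticsearch output API keys, and bumps the policy revision on the agent document.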
func (ack *AckT) handlePolicyChange(ctx context.Context, zlog zerolog.Logger, agent *model.Agent, actionIds ...string) error {
span, ctx := apm.StartSpan(ctx, "ackPolicyChanges", "process")
defer span.End()
// If more than one ack is present, pick the winner:
// 1) correct policy id
// 2) highest revision number
found := false
currRev := agent.PolicyRevisionIdx
vSpan, _ := apm.StartSpan(ctx, "checkPolicyActions", "validate")
for _, a := range actionIds {
rev, ok := policy.RevisionFromString(a)
zlog.Debug().
Str("agent.policyId", agent.PolicyID).
Int64("agent.revisionIdx", currRev).
Str("rev.policyId", rev.PolicyID).
Int64(logger.RevisionIdx, rev.RevisionIdx).
Msg("ack policy revision")
if ok && rev.PolicyID == agent.PolicyID && rev.RevisionIdx > currRev {
found = true
currRev = rev.RevisionIdx
}
}
vSpan.End()
if !found {
return nil
}
for outputName, output := range agent.Outputs {
if output.Type != policy.OutputTypeElasticsearch {
continue
}
err := ack.updateAPIKey(ctx,
zlog,
agent.Id,
output.APIKeyID, output.PermissionsHash, output.ToRetireAPIKeyIds, outputName)
if err != nil {
return err
}
}
err := ack.updateAgentDoc(ctx, zlog,
agent.Id,
currRev,
agent.PolicyID)
if err != nil {
return err
}
return nil
}
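
// updateAPIKey strips stale role descriptors from the output API key and retires any superseded keys;
// when an output-specific bulker exists it is used instead of the default one.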
func (ack *AckT) updateAPIKey(ctx context.Context,
zlog zerolog.Logger,
agentID string,
apiKeyID, permissionHash string,
toRetireAPIKeyIDs []model.ToRetireAPIKeyIdsItems, outputName string) error {
bulk := ack.bulk
// use the output bulker if it exists
if outputName != "" {
outputBulk := ack.bulk.GetBulker(outputName)
if outputBulk != nil {
zlog.Debug().Str(logger.PolicyOutputName, outputName).Msg("Using output bulker in updateAPIKey")
bulk = outputBulk
}
}
if apiKeyID != "" {
res, err := bulk.APIKeyRead(ctx, apiKeyID, true)
if err != nil {
if isAgentActive(ctx, zlog, ack.bulk, agentID) {
zlog.Warn().
Err(err).
Str(LogAPIKeyID, apiKeyID).
Str(logger.PolicyOutputName, outputName).
Msg("Failed to read API Key roles")
} else {
// race when API key was invalidated before acking
zlog.Info().
Err(err).
Str(LogAPIKeyID, apiKeyID).
Str(logger.PolicyOutputName, outputName).
Msg("Failed to read invalidated API Key roles")
// prevents future checks
return ErrUpdatingInactiveAgent
}
} else {
clean, removedRolesCount, err := cleanRoles(res.RoleDescriptors)
if err != nil {
zlog.Error().
Err(err).
RawJSON("roles", res.RoleDescriptors).
Str(LogAPIKeyID, apiKeyID).
Msg("Failed to cleanup roles")
} else if removedRolesCount > 0 {
if err := bulk.APIKeyUpdate(ctx, apiKeyID, permissionHash, clean); err != nil {
zlog.Error().Err(err).RawJSON("roles", clean).Str(LogAPIKeyID, apiKeyID).Str(logger.PolicyOutputName, outputName).Msg("Failed to update API Key")
} else {
zlog.Debug().
Str("hash.sha256", permissionHash).
Str(LogAPIKeyID, apiKeyID).
RawJSON("roles", clean).
Int("removedRoles", removedRolesCount).
Str(logger.PolicyOutputName, outputName).
Msg("Updating agent record to pick up reduced roles.")
}
}
}
ack.invalidateAPIKeys(ctx, zlog, toRetireAPIKeyIDs, apiKeyID)
}
return nil
}
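
// updateAgentDoc writes the acked policy revision to the agent document using the conditional update script built by makeUpdatePolicyBody.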
func (ack *AckT) updateAgentDoc(ctx context.Context,
zlog zerolog.Logger,
agentID string,
currRev int64,
policyID string,
) error {
span, ctx := apm.StartSpan(ctx, "updateAgentDoc", "update")
defer span.End()
body := makeUpdatePolicyBody(
policyID,
currRev,
)
err := ack.bulk.Update(
ctx,
dl.FleetAgents,
agentID,
body,
bulk.WithRefresh(),
bulk.WithRetryOnConflict(3),
)
zlog.Err(err).
Str(LogPolicyID, policyID).
Int64("policyRevision", currRev).
Msg("ack policy")
if err != nil {
return fmt.Errorf("handlePolicyChange update: %w", err)
}
return nil
}
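
// cleanRoles removes role descriptors whose names end in "-rdstale" and returns the cleaned roles along with the number removed.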
func cleanRoles(roles json.RawMessage) (json.RawMessage, int, error) {
rr := smap.Map{}
if err := json.Unmarshal(roles, &rr); err != nil {
return nil, 0, fmt.Errorf("failed to unmarshal provided roles: %w", err)
}
keys := make([]string, 0, len(rr))
for k := range rr {
if strings.HasSuffix(k, "-rdstale") {
keys = append(keys, k)
}
}
if len(keys) == 0 {
return roles, 0, nil
}
for _, k := range keys {
delete(rr, k)
}
r, err := json.Marshal(rr)
if err != nil {
return r, len(keys), fmt.Errorf("failed to marshal resulting role definition: %w", err)
}
return r, len(keys), nil
}
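
// invalidateAPIKeys retires the listed API keys, skipping the key identified by skip; it delegates to the package-level invalidateAPIKeys with the default bulker.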
func (ack *AckT) invalidateAPIKeys(ctx context.Context, zlog zerolog.Logger, toRetireAPIKeyIDs []model.ToRetireAPIKeyIdsItems, skip string) {
invalidateAPIKeys(ctx, zlog, ack.bulk, toRetireAPIKeyIDs, skip)
}
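
// handleUnenroll retires all of the agent's API keys and marks the agent document as inactive and unenrolled.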
func (ack *AckT) handleUnenroll(ctx context.Context, zlog zerolog.Logger, agent *model.Agent) error {
span, ctx := apm.StartSpan(ctx, "ackUnenroll", "process")
defer span.End()
apiKeys := agent.APIKeyIDs()
zlog.Info().Any("fleet.policy.apiKeyIDsToRetire", apiKeys).Msg("handleUnenroll invalidate API keys")
ack.invalidateAPIKeys(ctx, zlog, apiKeys, "")
now := time.Now().UTC().Format(time.RFC3339)
doc := bulk.UpdateFields{
dl.FieldActive: false,
dl.FieldUnenrolledAt: now,
dl.FieldUpdatedAt: now,
}
body, err := doc.Marshal()
if err != nil {
return fmt.Errorf("handleUnenroll marshal: %w", err)
}
if err = ack.bulk.Update(ctx, dl.FleetAgents, agent.Id, body, bulk.WithRefresh(), bulk.WithRetryOnConflict(3)); err != nil {
return fmt.Errorf("handleUnenroll update: %w", err)
}
zlog.Info().Msg("ack unenroll")
return nil
}
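
// handleUpgrade updates the agent document's upgrade state: failed or retrying when the event carries an error, upgraded (with upgraded_at set) otherwise.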
func (ack *AckT) handleUpgrade(ctx context.Context, zlog zerolog.Logger, agent *model.Agent, event UpgradeEvent) error {
span, ctx := apm.StartSpan(ctx, "ackUpgrade", "process")
defer span.End()
now := time.Now().UTC().Format(time.RFC3339)
doc := bulk.UpdateFields{}
if event.Error != nil {
// If the payload indicates a retry, mark the upgrade status as retrying; otherwise mark the upgrade as failed.
if event.Payload == nil {
zlog.Info().Msg("marking agent upgrade as failed, agent logs contain failure message")
doc = bulk.UpdateFields{
dl.FieldUpgradeStartedAt: nil,
dl.FieldUpgradeStatus: "failed",
}
} else if event.Payload.Retry {
zlog.Info().Int("retry_attempt", event.Payload.RetryAttempt).Msg("marking agent upgrade as retrying")
doc[dl.FieldUpgradeStatus] = "retrying" // Keep FieldUpgradeStartedAt and FieldUpgradedAt at their original values
} else {
zlog.Info().Int("retry_attempt", event.Payload.RetryAttempt).Msg("marking agent upgrade as failed, agent logs contain failure message")
doc = bulk.UpdateFields{
dl.FieldUpgradeStartedAt: nil,
dl.FieldUpgradeStatus: "failed",
}
}
} else {
doc = bulk.UpdateFields{
dl.FieldUpgradeStartedAt: nil,
dl.FieldUpgradeStatus: nil,
dl.FieldUpgradedAt: now,
}
if agent.UpgradeDetails == nil {
doc[dl.FieldUpgradeAttempts] = nil
}
}
body, err := doc.Marshal()
if err != nil {
return fmt.Errorf("handleUpgrade marshal: %w", err)
}
if err = ack.bulk.Update(ctx, dl.FleetAgents, agent.Id, body, bulk.WithRefresh(), bulk.WithRetryOnConflict(3)); err != nil {
return fmt.Errorf("handleUpgrade update: %w", err)
}
zlog.Info().
Str("lastReportedVersion", agent.Agent.Version).
Str("upgradedAt", now).
Str(logger.AgentID, agent.Agent.ID).
Str(logger.ActionID, event.ActionId).
Msg("ack upgrade")
return nil
}
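
// isAgentActive reports whether the agent document is still marked active; if the lookup fails it defaults to true.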
func isAgentActive(ctx context.Context, zlog zerolog.Logger, bulk bulk.Bulk, agentID string) bool {
agent, err := dl.FindAgent(ctx, bulk, dl.QueryAgentByID, dl.FieldID, agentID)
if err != nil {
zlog.Error().
Err(err).
Msg("failed to find agent by ID")
return true
}
return agent.Active // if the agent is active, the API key read failure is a genuine error (the key was not expected to be invalidated)
}
// Generate an update script that validates that the policy_id
// has not changed underneath us by an upstream process (Kibana or otherwise).
// We have a race condition where a user could have assigned a new policy to
// an agent while we were busy updating the old one. A blind update to the
// agent record without a check could set the revision for the wrong
// policy. This script should be coupled with a "retry_on_conflict" parameter
// to allow for *other* changes to the agent record while we are running the script.
// (For example, say the background bulk check-in timestamp update task fires)
//
// WARNING: This assumes the input data is sanitized.
const kUpdatePolicyPrefix = `{"script":{"lang":"painless","source":"if (ctx._source.policy_id == params.id) {ctx._source.remove('default_api_key_history');ctx._source.` +
dl.FieldPolicyRevisionIdx +
` = params.rev;ctx._source.` +
dl.FieldUpdatedAt +
` = params.ts;} else {ctx.op = \"noop\";}","params": {"id":"`
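
// makeUpdatePolicyBody completes kUpdatePolicyPrefix with the policy id, revision, and current timestamp.
// The resulting body is roughly (the exact field names come from the dl constants):
//   {"script":{"lang":"painless","source":"...","params":{"id":"<policyID>","rev":<newRev>,"ts":"<RFC3339 timestamp>"}}}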
func makeUpdatePolicyBody(policyID string, newRev int64) []byte {
var buf bytes.Buffer
buf.Grow(410)
// Not pretty, but fast.
buf.WriteString(kUpdatePolicyPrefix)
buf.WriteString(policyID)
buf.WriteString(`","rev":`)
buf.WriteString(strconv.FormatInt(newRev, 10))
buf.WriteString(`,"ts":"`)
buf.WriteString(time.Now().UTC().Format(time.RFC3339))
buf.WriteString(`"}}}`)
return buf.Bytes()
}
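
// invalidateAPIKeys invalidates the given API keys, skipping the one identified by skip.
// Keys that belong to a remote output are invalidated through that output's bulker, recreating it from the stored policy if needed.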
func invalidateAPIKeys(ctx context.Context, zlog zerolog.Logger, bulk bulk.Bulk, toRetireAPIKeyIDs []model.ToRetireAPIKeyIdsItems, skip string) {
ids := make([]string, 0, len(toRetireAPIKeyIDs))
remoteIds := make(map[string][]string)
for _, k := range toRetireAPIKeyIDs {
if k.ID == skip || k.ID == "" {
continue
}
if k.Output != "" {
if remoteIds[k.Output] == nil {
remoteIds[k.Output] = make([]string, 0)
}
remoteIds[k.Output] = append(remoteIds[k.Output], k.ID)
} else {
ids = append(ids, k.ID)
}
}
if len(ids) > 0 {
zlog.Info().Strs("fleet.policy.apiKeyIDsToRetire", ids).Msg("Invalidate old API keys")
if err := bulk.APIKeyInvalidate(ctx, ids...); err != nil {
zlog.Info().Err(err).Strs("ids", ids).Msg("Failed to invalidate API keys")
}
}
// use the remote ES output's bulker to invalidate its API keys
for outputName, outputIds := range remoteIds {
outputBulk := bulk.GetBulker(outputName)
if outputBulk == nil {
// read output config from .fleet-policies, not filtering by policy id as agent could be reassigned
policy, err := dl.QueryOutputFromPolicy(ctx, bulk, outputName)
if err != nil || policy == nil {
zlog.Warn().Str(logger.PolicyOutputName, outputName).Any("ids", outputIds).Msg("Output policy not found, API keys will be orphaned")
} else {
outputBulk, _, err = bulk.CreateAndGetBulker(ctx, zlog, outputName, policy.Data.Outputs)
if err != nil {
zlog.Warn().Str(logger.PolicyOutputName, outputName).Any("ids", outputIds).Msg("Failed to recreate output bulker, API keys will be orphaned")
}
}
}
if outputBulk != nil {
if err := outputBulk.APIKeyInvalidate(ctx, outputIds...); err != nil {
zlog.Info().Err(err).Strs("ids", outputIds).Str(logger.PolicyOutputName, outputName).Msg("Failed to invalidate API keys")
}
}
}
}