# shippers/es.py
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License 2.0;
# you may not use this file except in compliance with the Elastic License 2.0.
import datetime
import http
import uuid
from typing import Any, Dict, Optional, Union
import elasticapm # noqa: F401
from elasticsearch import Elasticsearch
from elasticsearch.exceptions import SerializationError
from elasticsearch.helpers import bulk as es_bulk
from elasticsearch.serializer import Serializer
import share.utils
from share import json_dumper, json_parser, normalise_event, shared_logger
from share.environment import get_environment
from share.version import version
from .shipper import EventIdGeneratorCallable, ReplayHandlerCallable
_EVENT_BUFFERED = "_EVENT_BUFFERED"
_EVENT_SENT = "_EVENT_SENT"
# List of HTTP status codes that are considered retryable
_retryable_http_status_codes = [
http.HTTPStatus.TOO_MANY_REQUESTS,
]
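# Actions that fail with one of these status codes, or with a connection error
# (no status code at all), are not written to the dead letter index; they are
# left for the replay handler to send again (see _send_dead_letter_index).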
class JSONSerializer(Serializer):
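    """Serializer for the Elasticsearch client that delegates to the shared json_parser/json_dumper helpers."""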
mimetype = "application/json"
def loads(self, s: str) -> Any:
try:
return json_parser(s)
except Exception as e:
raise SerializationError(s, e)
def dumps(self, data: Any) -> str:
if isinstance(data, str):
return data
if isinstance(data, bytes):
return data.decode("utf-8")
try:
return json_dumper(data)
except Exception as e:
raise SerializationError(data, e)
class ElasticsearchShipper:
"""
Elasticsearch Shipper.
This class implements concrete Elasticsearch Shipper
"""
def __init__(
self,
elasticsearch_url: str = "",
username: str = "",
password: str = "",
cloud_id: str = "",
api_key: str = "",
es_datastream_name: str = "",
es_dead_letter_index: str = "",
tags: list[str] = [],
batch_max_actions: int = 500,
batch_max_bytes: int = 10 * 1024 * 1024,
ssl_assert_fingerprint: str = "",
):
self._bulk_actions: list[dict[str, Any]] = []
self._bulk_batch_size = batch_max_actions
self._bulk_kwargs: dict[str, Any] = {
"max_retries": 4,
"stats_only": False,
"raise_on_error": False,
"raise_on_exception": False,
"max_chunk_bytes": batch_max_bytes,
}
if batch_max_actions > 0:
self._bulk_kwargs["chunk_size"] = batch_max_actions
es_client_kwargs: dict[str, Any] = {}
if elasticsearch_url:
es_client_kwargs["hosts"] = [elasticsearch_url]
self._output_destination = elasticsearch_url
elif cloud_id:
es_client_kwargs["cloud_id"] = cloud_id
self._output_destination = cloud_id
else:
raise ValueError("You must provide one between elasticsearch_url or cloud_id")
if username:
es_client_kwargs["http_auth"] = (username, password)
elif api_key:
es_client_kwargs["api_key"] = api_key
else:
raise ValueError("You must provide one between username and password or api_key")
if ssl_assert_fingerprint:
es_client_kwargs["verify_certs"] = False
es_client_kwargs["ssl_assert_fingerprint"] = ssl_assert_fingerprint
es_client_kwargs["serializer"] = JSONSerializer()
self._replay_args: dict[str, Any] = {}
self._es_client = self._elasticsearch_client(**es_client_kwargs)
self._replay_handler: Optional[ReplayHandlerCallable] = None
self._event_id_generator: Optional[EventIdGeneratorCallable] = None
self._es_datastream_name = es_datastream_name
self._es_dead_letter_index = es_dead_letter_index
self._tags = tags
self._es_index = ""
self._dataset = ""
self._namespace = ""
@staticmethod
def _elasticsearch_client(**es_client_kwargs: Any) -> Elasticsearch:
"""
Getter for elasticsearch client
Extracted for mocking
"""
es_client_kwargs["timeout"] = 30
es_client_kwargs["max_retries"] = 4
es_client_kwargs["http_compress"] = True
es_client_kwargs["retry_on_timeout"] = True
es_client_kwargs["headers"] = {
"User-Agent": share.utils.create_user_agent(esf_version=version, environment=get_environment())
}
return Elasticsearch(**es_client_kwargs)
def _enrich_event(self, event_payload: dict[str, Any]) -> None:
"""
This method enrich with default metadata the ES event payload.
Currently, hardcoded for logs type
"""
if "fields" not in event_payload:
return
event_payload["tags"] = ["forwarded"]
if self._dataset != "":
event_payload["data_stream"] = {
"type": "logs",
"dataset": self._dataset,
"namespace": self._namespace,
}
event_payload["event"] = {"dataset": self._dataset}
event_payload["tags"] += [self._dataset.replace(".", "-")]
event_payload["tags"] += self._tags
def _handle_outcome(self, actions: list[dict[str, Any]], errors: tuple[int, Union[int, list[Any]]]) -> list[Any]:
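        """
        Inspect the (success count, errors) tuple returned by the bulk helper,
        log every error, and return the failed actions that should be forwarded
        to the dead letter index or replay queue. Version conflicts and errors
        whose _id no longer matches a buffered action are skipped.
        """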
assert isinstance(errors[1], list)
success = errors[0]
failed: list[Any] = []
for error in errors[1]:
action_failed = [action for action in actions if action["_id"] == error["create"]["_id"]]
# an ingestion pipeline might override the _id, we can only skip in this case
if len(action_failed) != 1:
continue
shared_logger.warning(
"elasticsearch shipper", extra={"error": error["create"]["error"], "_id": error["create"]["_id"]}
)
if "status" in error["create"] and error["create"]["status"] == http.HTTPStatus.CONFLICT:
# Skip duplicate events on dead letter index and replay queue
continue
failed_error = {"action": action_failed[0]} | self._parse_error(error["create"])
failed.append(failed_error)
if len(failed) > 0:
shared_logger.warning("elasticsearch shipper", extra={"success": success, "failed": len(failed)})
else:
shared_logger.info("elasticsearch shipper", extra={"success": success, "failed": len(failed)})
return failed
def _parse_error(self, error: dict[str, Any]) -> dict[str, Any]:
"""
Parses the error response from Elasticsearch and returns a
standardised error field.
The error field is a dictionary with the following keys:
- `message`: The error message
- `type`: The error type
If the error is not recognised, the `message` key is set
to "Unknown error".
It also sets the status code in the http field if it is present
as a number in the response.
"""
field: dict[str, Any] = {"error": {"message": "Unknown error", "type": "unknown"}}
if "status" in error and isinstance(error["status"], int):
            # Collect the HTTP response status code in the http field when it
            # is present and is an integer.
            #
            # Sometimes the status code is a string, for example when the
            # connection to the server fails.
field["http"] = {"response": {"status_code": error["status"]}}
if "error" not in error:
return field
if isinstance(error["error"], str):
# Can happen with connection errors.
field["error"]["message"] = error["error"]
if "exception" in error:
# The exception field is usually an Exception object,
# so we convert it to a string.
field["error"]["type"] = str(type(error["exception"]))
elif isinstance(error["error"], dict):
# Can happen with status 5xx errors.
# In this case, we look for the "reason" and "type" fields.
if "reason" in error["error"]:
field["error"]["message"] = error["error"]["reason"]
if "type" in error["error"]:
field["error"]["type"] = error["error"]["type"]
return field
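    # Illustrative example of the parsing above (the input is a sketch of a
    # bulk "create" item, not a verbatim Elasticsearch response):
    #   _parse_error({"status": 400, "error": {"type": "mapper_parsing_exception", "reason": "failed to parse"}})
    # returns
    #   {"error": {"message": "failed to parse", "type": "mapper_parsing_exception"},
    #    "http": {"response": {"status_code": 400}}}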
def set_event_id_generator(self, event_id_generator: EventIdGeneratorCallable) -> None:
self._event_id_generator = event_id_generator
def set_replay_handler(self, replay_handler: ReplayHandlerCallable) -> None:
self._replay_handler = replay_handler
def send(self, event: dict[str, Any]) -> str:
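        """
        Buffer the event as a bulk "create" action and flush the buffer once it
        reaches batch_max_actions. Returns _EVENT_BUFFERED while the event is
        only buffered and _EVENT_SENT after a flush.
        """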
self._replay_args["es_datastream_name"] = self._es_datastream_name
if not hasattr(self, "_es_index") or self._es_index == "":
self._discover_dataset(event_payload=event)
self._enrich_event(event_payload=event)
event["_op_type"] = "create"
if "_index" not in event:
event["_index"] = self._es_index
if "_id" not in event and self._event_id_generator is not None:
event["_id"] = self._event_id_generator(event)
event = normalise_event(event_payload=event)
self._bulk_actions.append(event)
if len(self._bulk_actions) < self._bulk_batch_size:
return _EVENT_BUFFERED
self.flush()
return _EVENT_SENT
def flush(self) -> None:
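        """
        Send the buffered actions in bulk, then route any failed actions to the
        dead letter index and/or the replay handler, when configured.
        """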
if len(self._bulk_actions) == 0:
return
errors = es_bulk(self._es_client, self._bulk_actions, **self._bulk_kwargs)
failed = self._handle_outcome(actions=self._bulk_actions, errors=errors)
# Send failed requests to dead letter index, if enabled
if len(failed) > 0 and self._es_dead_letter_index:
failed = self._send_dead_letter_index(failed)
# Send remaining failed requests to replay queue, if enabled
if isinstance(failed, list) and len(failed) > 0 and self._replay_handler is not None:
for outcome in failed:
if "action" not in outcome:
shared_logger.error("action could not be extracted to be replayed", extra={"outcome": outcome})
continue
self._replay_handler(self._output_destination, self._replay_args, outcome["action"])
self._bulk_actions = []
return
def _send_dead_letter_index(self, actions: list[Any]) -> list[Any]:
"""
Index the failed actions in the dead letter index (DLI).
This function attempts to index failed actions to the DLI, but may not do so
for one of the following reasons:
1. The failed action could not be encoded for indexing in the DLI.
2. ES returned an error on the attempt to index the failed action in the DLI.
3. The failed action error is retryable (connection error or status code 429).
Retryable errors are not indexed in the DLI, as they are expected to be
sent again to the data stream at `es_datastream_name` by the replay handler.
Args:
actions (list[Any]): A list of actions to index in the DLI.
Returns:
list[Any]: A list of actions that were not indexed in the DLI due to one of
the reasons mentioned above.
"""
non_indexed_actions: list[Any] = []
encoded_actions = []
for action in actions:
if (
"http" not in action # no http status: connection error
or action["http"]["response"]["status_code"] in _retryable_http_status_codes
):
# We don't want to forward this action to
# the dead letter index.
#
# Add the action to the list of non-indexed
# actions and continue with the next one.
non_indexed_actions.append(action)
continue
# Reshape event to dead letter index
encoded = self._encode_dead_letter(action)
            if not encoded:
                shared_logger.error("cannot encode dead letter index event from payload", extra={"action": action})
                non_indexed_actions.append(action)
                continue
            encoded_actions.append(encoded)
        # If no action could be encoded, return the non-indexed actions as failed
if len(encoded_actions) == 0:
return non_indexed_actions
errors = es_bulk(self._es_client, encoded_actions, **self._bulk_kwargs)
failed = self._handle_outcome(actions=encoded_actions, errors=errors)
if not isinstance(failed, list) or len(failed) == 0:
return non_indexed_actions
for action in failed:
event_payload = self._decode_dead_letter(action)
if not event_payload:
shared_logger.error("cannot decode dead letter index event from payload", extra={"action": action})
continue
non_indexed_actions.append(event_payload)
return non_indexed_actions
def _encode_dead_letter(self, outcome: dict[str, Any]) -> dict[str, Any]:
if "action" not in outcome or "error" not in outcome:
return {}
        # Assign a random _id so that, if bulk() reports an error for this
        # document, it can be matched back to the original action
encoded = {
"@timestamp": datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
"_id": str(uuid.uuid4()),
"_index": self._es_dead_letter_index,
"_op_type": "create",
"message": json_dumper(outcome["action"]),
"error": outcome["error"],
}
if "http" in outcome:
# the `http.response.status_code` is not
# always present in the error field.
encoded["http"] = outcome["http"]
return encoded
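    # Illustrative shape of an encoded dead letter document (values are an
    # example only):
    #   {"@timestamp": "2024-01-01T00:00:00.000000Z", "_id": "<random uuid>",
    #    "_index": "<es_dead_letter_index>", "_op_type": "create",
    #    "message": "<original action serialized as JSON>", "error": {...}}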
def _decode_dead_letter(self, dead_letter_outcome: dict[str, Any]) -> dict[str, Any]:
if "action" not in dead_letter_outcome or "message" not in dead_letter_outcome["action"]:
return {}
return {"action": json_parser(dead_letter_outcome["action"]["message"])}
def _discover_dataset(self, event_payload: Dict[str, Any]) -> None:
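        """
        When es_datastream_name is set, use it as the target index and, if it
        follows the logs-<dataset>-<namespace> pattern, derive dataset and
        namespace from it. Otherwise fall back to the event's
        meta.integration_scope (or "generic") with the "default" namespace and
        build a logs-<dataset>-<namespace> index name.
        """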
if self._es_datastream_name != "":
if self._es_datastream_name.startswith("logs-"):
datastream_components = self._es_datastream_name.split("-")
if len(datastream_components) == 3:
self._dataset = datastream_components[1]
self._namespace = datastream_components[2]
else:
shared_logger.debug(
"es_datastream_name not matching logs datastream pattern, no dataset and namespace set"
)
else:
shared_logger.debug(
"es_datastream_name not matching logs datastream pattern, no dataset and namespace set"
)
self._es_index = self._es_datastream_name
return
else:
self._namespace = "default"
if "meta" not in event_payload or "integration_scope" not in event_payload["meta"]:
self._dataset = "generic"
else:
self._dataset = event_payload["meta"]["integration_scope"]
if self._dataset == "aws.cloudtrail-digest":
self._dataset = "aws.cloudtrail"
if self._dataset == "generic":
shared_logger.debug("dataset set to generic")
shared_logger.debug("dataset", extra={"dataset": self._dataset})
self._es_index = f"logs-{self._dataset}-{self._namespace}"
self._es_datastream_name = self._es_index
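# Minimal usage sketch, kept as a comment so importing this module has no side
# effects. The URL, API key, and data stream name below are placeholders, not
# values taken from this project:
#
#   shipper = ElasticsearchShipper(
#       elasticsearch_url="https://localhost:9200",
#       api_key="<api-key>",
#       es_datastream_name="logs-generic-default",
#       tags=["my-tag"],
#   )
#   shipper.send({"fields": {"message": "hello"}})  # buffered until the batch fills up
#   shipper.flush()  # force-send whatever is still buffered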