# scripts/share_ad_job_state/export_model_snapshot.py
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
import argparse
import json
import os
import re
import tarfile
from datetime import datetime
from getpass import getpass
from typing import Any, Dict, List, Optional, Set, Tuple
import urllib3
from elasticsearch import ApiError, Elasticsearch, TransportError, helpers
from loguru import logger
from tqdm import tqdm
# Disable noisy warning about missing certificate verification
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# Constants
KNOWN_OPERATORS = {
"query",
"bool",
"must",
"should",
"filter",
"must_not",
"match",
"term",
"terms",
"range",
"exists",
"missing",
"wildcard",
"regexp",
"fuzzy",
"prefix",
"multi_match",
"match_phrase",
"match_phrase_prefix",
"simple_query_string",
"common",
"ids",
"constant_score",
"dis_max",
"function_score",
"nested",
"has_child",
"has_parent",
"more_like_this",
"script",
"percolate",
"geo_shape",
"geo_bounding_box",
"geo_distance",
"geo_polygon",
"shape",
"parent_id",
"boosting",
"indices",
"span_term",
"span_multi",
"span_first",
"span_near",
"span_or",
"span_not",
"span_containing",
"span_within",
"span_field_masking",
}
def sanitize_filename(name: str) -> str:
"""Sanitize the filename to prevent directory traversal and other security issues."""
return re.sub(r"[^a-zA-Z0-9_-]", "_", name)
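# Illustrative example (not executed): sanitize_filename("web traffic/EU #2")
# returns "web_traffic_EU__2" -- every character outside [a-zA-Z0-9_-] is
# replaced with an underscore, so job IDs are safe to embed in file names.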
def validate_date(date_text: str) -> datetime:
"""Validate and parse date string in the format YYYY-MM-DDTHH:MM:SS."""
try:
return datetime.strptime(date_text, "%Y-%m-%dT%H:%M:%S")
except ValueError as e:
raise argparse.ArgumentTypeError(
f"Invalid date format: '{date_text}'. Expected format: YYYY-MM-DDTHH:MM:SS"
) from e
def save_snapshots(
job_id: str, snapshot_id: str, es_client: Elasticsearch, snapshot_doc_count: int
) -> Optional[str]:
"""
Extract the model state from Elasticsearch based on the given parameters.
Args:
job_id (str): The ID of the job.
snapshot_id (str): The ID of the snapshot.
es_client (Elasticsearch): The Elasticsearch client.
snapshot_doc_count (int): The number of snapshot documents to extract.
Returns:
Optional[str]: The filename of the saved snapshot documents, or None if failed.
"""
index_pattern = ".ml-state-*"
safe_job_id = sanitize_filename(job_id)
filename = f"{safe_job_id}_snapshot_docs.ndjson"
logger.info(f"Writing the compressed model state to {filename}")
all_ids = [
f"{job_id}_model_state_{snapshot_id}#{i + 1}" for i in range(snapshot_doc_count)
]
num_docs = 0
try:
with open(filename, "w", encoding="utf-8") as f:
for doc in tqdm(
helpers.scan(
es_client,
index=index_pattern,
query={"query": {"terms": {"_id": all_ids}}},
size=1000,
),
total=snapshot_doc_count,
desc="Saving snapshots",
):
action = {"index": {"_index": doc["_index"], "_id": doc["_id"]}}
f.write(json.dumps(action) + "\n")
f.write(json.dumps(doc["_source"]) + "\n")
num_docs += 1
logger.info(
f"{num_docs} snapshot documents for job {job_id} stored in {filename}"
)
return filename
except (ApiError, TransportError, IOError) as e:
logger.error(f"Failed to save snapshots: {e}")
return None
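# The NDJSON written by save_snapshots() alternates a bulk "index" action line
# with the document _source, so the file can later be replayed into another
# cluster. Minimal sketch, assuming `target_client` (a hypothetical, separately
# configured Elasticsearch client for the destination) and `filename` as
# returned above:
#
#   with open(filename, encoding="utf-8") as src:
#       lines = [json.loads(line) for line in src]
#   actions = [
#       {"_index": meta["index"]["_index"], "_id": meta["index"]["_id"], "_source": body}
#       for meta, body in zip(lines[0::2], lines[1::2])
#   ]
#   helpers.bulk(target_client, actions)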
def save_snapshot_stats(
job_id: str, snapshot_id: str, es_client: Elasticsearch
) -> Optional[str]:
"""
    Retrieves and saves the model snapshot statistics document for the given
    job and snapshot ID.
Args:
job_id (str): The ID of the job.
snapshot_id (str): The ID of the snapshot.
es_client (Elasticsearch): The Elasticsearch client.
Returns:
Optional[str]: The filename containing the snapshot statistics, or None if failed.
"""
index = ".ml-anomalies-shared"
search_query = {
"query": {
"bool": {
"must": [
{"term": {"job_id": job_id}},
{"term": {"snapshot_id": snapshot_id}},
]
}
}
}
try:
response = es_client.search(index=index, body=search_query)
hits = response.get("hits", {}).get("hits", [])
if hits:
result_doc = hits[0]["_source"]
id_ = sanitize_filename(hits[0]["_id"])
file_name = f"ml-anomalies-snapshot_doc_{id_}.json"
with open(file_name, "w", encoding="utf-8") as f:
json.dump(result_doc, f, indent=4)
logger.info(f"Snapshot document count stored in {file_name}")
return file_name
else:
logger.error(
"No snapshot document found for the given job_id and snapshot_id."
)
return None
except (ApiError, TransportError) as e:
logger.error(f"Error retrieving snapshot stats: {e}")
return None
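# Note: the document saved by save_snapshot_stats() is the model snapshot
# metadata stored in the results index (including fields such as the snapshot's
# document count), which is handy for cross-checking a later import.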
def save_job_config(
job_id: str, es_client: Elasticsearch
) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
"""
Retrieves the job configuration using the Elasticsearch anomaly detection job API.
Args:
job_id (str): The ID of the job.
es_client (Elasticsearch): The Elasticsearch client.
Returns:
Tuple[Optional[str], Optional[Dict[str, Any]]]: The filename containing the job
configuration and the job configuration dictionary, or (None, None) if failed.
"""
try:
response = es_client.ml.get_jobs(job_id=job_id)
if response.get("count", 0) > 0:
config = response["jobs"][0]
safe_job_id = sanitize_filename(job_id)
file_name = f"{safe_job_id}_config.json"
with open(file_name, "w", encoding="utf-8") as f:
json.dump(config, f, indent=4)
logger.info(f"Job configuration for job {job_id} stored in {file_name}")
return file_name, config
else:
logger.error(f"No job configuration found for job_id: {job_id}")
return None, None
except (ApiError, TransportError) as e:
logger.error(
f"Error retrieving job configuration for job_id: {job_id}. Error: {e}"
)
return None, None
def save_annotations(
job_id: str,
before_date: Optional[datetime],
after_date: Optional[datetime],
es_client: Elasticsearch,
) -> Optional[str]:
"""
Retrieves annotations for the given job within the specified date range.
Args:
job_id (str): The ID of the job.
before_date (Optional[datetime]): The upper bound for the create_time.
after_date (Optional[datetime]): The lower bound for the create_time.
es_client (Elasticsearch): The Elasticsearch client.
Returns:
Optional[str]: The filename containing the annotations, or None if failed.
"""
index = ".ml-annotations-read"
date_range = {}
if before_date:
date_range["lte"] = before_date.isoformat()
if after_date:
date_range["gte"] = after_date.isoformat()
search_query = {
"query": {
"bool": {
"must": [
{"term": {"job_id": job_id}},
{"range": {"create_time": date_range}},
]
}
},
"size": 10000,
}
try:
response = es_client.search(index=index, body=search_query)
annotations = response.get("hits", {}).get("hits", [])
safe_job_id = sanitize_filename(job_id)
filename = f"{safe_job_id}_annotations.ndjson"
with open(filename, "w", encoding="utf-8") as f:
for annotation in annotations:
f.write(json.dumps(annotation["_source"]) + "\n")
logger.info(f"Annotations for job {job_id} stored in {filename}")
return filename
except (ApiError, TransportError, IOError) as e:
logger.error(f"Failed to save annotations: {e}")
return None
def extract_field_names_from_json(
query_json: Dict[str, Any], known_operators: Set[str]
) -> Set[str]:
"""Recursively extract field names from a JSON query."""
field_names = set()
def recurse(obj: Any) -> None:
if isinstance(obj, dict):
for key, value in obj.items():
if key not in known_operators:
field_names.add(key)
recurse(value)
elif isinstance(obj, list):
for item in obj:
recurse(item)
recurse(query_json)
return field_names
def extract_possible_field_names(query: Dict[str, Any]) -> Set[str]:
"""Extract possible field names from the Elasticsearch query."""
field_names = extract_field_names_from_json(query, KNOWN_OPERATORS)
# Drop the '.keyword' suffix
field_names = {field.split(".keyword")[0] for field in field_names}
logger.info(f"Extracted field names: {field_names}")
return field_names
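# Illustrative example (not executed): for a datafeed query such as
#   {"bool": {"must": [{"term": {"host.name.keyword": "web-1"}}]}}
# the keys "bool", "must" and "term" are dropped via KNOWN_OPERATORS, leaving
# "host.name.keyword", which is then trimmed to "host.name". The extraction
# over-approximates (e.g. keys like "gte" inside a range clause are picked up
# too); that is harmless here because the names are only used as a _source
# include filter when exporting input documents in save_inputs().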
def save_inputs(
job_config: Dict[str, Any],
before_date: Optional[datetime],
after_date: Optional[datetime],
es_client: Elasticsearch,
) -> Optional[str]:
"""
Extracts input data based on the job configuration and date range.
Args:
job_config (Dict[str, Any]): The job configuration dictionary.
before_date (Optional[datetime]): The upper bound for the time range.
after_date (Optional[datetime]): The lower bound for the time range.
es_client (Elasticsearch): The Elasticsearch client.
Returns:
Optional[str]: The filename containing the input data, or None if failed.
"""
indices = job_config["datafeed_config"]["indices"]
job_id = job_config["job_id"]
time_field = job_config["data_description"]["time_field"]
query = job_config["datafeed_config"].get("query", {"match_all": {}})
# Extract fields from detectors
field_keys = [
"field_name",
"partition_field_name",
"categorization_field_name",
"by_field_name",
"over_field_name",
"summary_count_field_name",
]
fields = {
detector[key]
for detector in job_config["analysis_config"]["detectors"]
for key in field_keys
if key in detector
}
fields.update(job_config["analysis_config"].get("influencers", []))
fields.add(time_field)
fields.update(extract_possible_field_names(query))
# Remove any '.keyword' suffixes
fields = {field.split(".keyword")[0] for field in fields}
date_range = {}
if before_date:
date_range["lte"] = before_date.isoformat()
if after_date:
date_range["gte"] = after_date.isoformat()
search_query = {
"_source": list(fields),
"query": {
"bool": {
"must": [
query,
{"range": {time_field: date_range}},
]
}
},
"size": 1000,
}
safe_job_id = sanitize_filename(job_id)
filename = f"{safe_job_id}_input.ndjson"
num_docs = 0
try:
with open(filename, "w", encoding="utf-8") as f:
for doc in tqdm(
helpers.scan(es_client, index=indices, query=search_query),
desc="Saving input data",
):
action = {"index": {"_index": doc["_index"], "_id": doc["_id"]}}
f.write(json.dumps(action) + "\n")
f.write(json.dumps(doc["_source"]) + "\n")
num_docs += 1
logger.info(f"{num_docs} input documents for job stored in {filename}")
return filename
except (ApiError, TransportError, IOError) as e:
logger.error(f"Failed to save input data: {e}")
return None
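# The input NDJSON uses the same bulk action/_source layout as save_snapshots(),
# so it can be re-indexed in the same way as sketched above.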
def create_archive(job_id: str, files: List[Optional[str]]) -> None:
"""
Creates a tar.gz archive containing the specified files.
Args:
job_id (str): The ID of the job.
files (List[Optional[str]]): List of file paths to include in the archive.
"""
safe_job_id = sanitize_filename(job_id)
archive_name = f"{safe_job_id}_state.tar.gz"
try:
with tarfile.open(archive_name, "w:gz") as tar:
for file in files:
if file and os.path.exists(file):
tar.add(file, arcname=os.path.basename(file))
logger.info(f"Added {file} to archive {archive_name}")
else:
logger.warning(f"File {file} not found, skipping.")
logger.info(f"Archive {archive_name} created successfully.")
    except IOError as e:
        logger.error(f"Failed to create archive: {e}")
    else:
        # Remove the temporary files only once the archive has been created successfully
        logger.info("Removing temporary files")
        for file in files:
            if file and os.path.exists(file):
                os.remove(file)
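# Illustrative example (not executed): the archive produced above can be
# unpacked again with the standard library, e.g.
#
#   with tarfile.open("my_job_state.tar.gz", "r:gz") as tar:
#       tar.extractall(path="extracted_state")
#
# where "my_job_state.tar.gz" stands in for the generated <job_id>_state.tar.gz.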
def get_snapshot_info(
es_client: Elasticsearch, job_id: str, before_date: Optional[datetime] = None
) -> Optional[Tuple[str, int]]:
"""
Retrieves the latest snapshot information for a given job.
Args:
es_client (Elasticsearch): The Elasticsearch client.
job_id (str): The ID of the job.
before_date (Optional[datetime]): The date before which to retrieve the snapshot.
Returns:
Optional[Tuple[str, int]]: The snapshot ID and snapshot document count, or None if failed.
"""
try:
snapshot_response = es_client.ml.get_model_snapshots(
job_id=job_id,
end=before_date.isoformat() if before_date else None,
desc=True,
)
if snapshot_response.get("count", 0) > 0:
latest_snapshot = snapshot_response["model_snapshots"][0]
snapshot_id = latest_snapshot["snapshot_id"]
snapshot_doc_count = latest_snapshot["snapshot_doc_count"]
logger.info(f"Latest snapshot ID for job {job_id}: {snapshot_id}")
return snapshot_id, snapshot_doc_count
else:
logger.error("No snapshots found before the given date.")
return None
except (ApiError, TransportError) as e:
logger.error(f"Error retrieving snapshot info: {e}")
return None
def main() -> None:
"""
Main function to extract model state from Elasticsearch.
Example usage:
python export_model_snapshot.py --url https://localhost:9200 --username user --job_id <job_id> --before_date 2023-05-10T00:00:00 --after_date 2023-01-01T00:00:00 --include_inputs
Output:
<job_id>_state.tar.gz archive with the following files:
- ml-anomalies-snapshot_doc_<id>.json snapshot stats document
- <job_id>_config.json job configuration
- <job_id>_snapshot_docs.ndjson snapshot documents
- <job_id>_annotations.ndjson annotations
- <job_id>_input.ndjson input data (if --include_inputs flag is set)
"""
parser = argparse.ArgumentParser(
description=(
"Extract model state from Elasticsearch. WARNING: This operation will extract data that may include PII."
)
)
parser.add_argument(
"--url",
type=str,
required=False,
default="https://localhost:9200",
help="Elasticsearch URL",
)
parser.add_argument(
"--username",
type=str,
required=True,
help="Username for Elasticsearch authentication",
)
parser.add_argument(
"--password",
type=str,
required=False,
help="Password for Elasticsearch authentication",
)
parser.add_argument(
"--job_id", type=str, required=True, help="Job ID to extract model state"
)
parser.add_argument(
"--cloud_id", type=str, required=False, help="Cloud ID for Elasticsearch"
)
parser.add_argument(
"--before_date",
type=validate_date,
required=False,
help="Search for the latest snapshot CREATED before the given date (format: YYYY-MM-DDTHH:MM:SS)",
)
parser.add_argument(
"--after_date",
type=validate_date,
required=False,
help="Search for input data and annotations after the specified date (format: YYYY-MM-DDTHH:MM:SS)",
)
parser.add_argument(
"--include_inputs",
action="store_true",
help="Include input data in the archive",
)
parser.add_argument(
"--ignore_certs",
action="store_true",
help="Disable SSL certificate verification",
)
args = parser.parse_args()
# Handle password securely
if not args.password:
args.password = getpass(prompt="Enter Elasticsearch password: ")
# Warn about PII data
logger.warning("This operation will extract data that may include PII.")
confirm = input("Do you wish to continue? (yes/no): ")
if confirm.lower() != "yes":
logger.info("Operation aborted by the user.")
return
logger.info("Connecting to Elasticsearch")
# Connect to an Elasticsearch instance
try:
if args.cloud_id:
logger.info("Connecting to Elasticsearch cloud using cloud_id")
es_client = Elasticsearch(
cloud_id=args.cloud_id,
basic_auth=(args.username, args.password),
verify_certs=(not args.ignore_certs),
)
else:
logger.info("Connecting to Elasticsearch using URL")
es_client = Elasticsearch(
[args.url],
basic_auth=(args.username, args.password),
verify_certs=(not args.ignore_certs),
            )
        # The client connects lazily, so make a lightweight request to surface
        # connection or authentication problems here rather than on first use.
        es_client.info()
    except (ApiError, TransportError) as e:
        logger.error(f"Failed to connect to Elasticsearch: {e}")
        return
snapshot_info = get_snapshot_info(es_client, args.job_id, args.before_date)
if snapshot_info is None:
logger.error("Failed to retrieve snapshot info.")
return
snapshot_id, snapshot_doc_count = snapshot_info
# Get the snapshot document count and store the result
file_name_ml_anomalies = save_snapshot_stats(args.job_id, snapshot_id, es_client)
# Get the job configuration and store it
    file_name_job_config, job_configuration = save_job_config(args.job_id, es_client)
    if job_configuration is None:
        logger.error("Failed to retrieve job configuration.")
        return
# Get the annotations and store them
file_name_annotations = save_annotations(
args.job_id, args.before_date, args.after_date, es_client
)
# Get the input data and store it
    if args.include_inputs:
        file_name_inputs = save_inputs(
            job_configuration, args.before_date, args.after_date, es_client
        )
    else:
        logger.info("Input data will not be included in the archive.")
        file_name_inputs = None
# Call the function to extract model state
filename_snapshots = save_snapshots(
args.job_id, snapshot_id, es_client, snapshot_doc_count
)
# Create an archive with all generated files
files_to_archive = [
file_name_ml_anomalies,
file_name_job_config,
filename_snapshots,
file_name_annotations,
file_name_inputs,
]
create_archive(args.job_id, files_to_archive)
if __name__ == "__main__":
main()