scripts/share_ad_job_state/import_model_snapshot.py
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
import argparse
import json
import os
import re
import tarfile
import time
from getpass import getpass
from typing import Any, Dict, List, Optional
# Disable noisy warning about missing certificate verification
import urllib3
from elasticsearch import ApiError, Elasticsearch, TransportError, helpers
from loguru import logger
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
def sanitize_filename(name: str) -> str:
"""Sanitize filenames to prevent directory traversal and other security issues."""
return re.sub(r"[^a-zA-Z0-9_-]", "_", name)
def is_within_directory(directory: str, target: str) -> bool:
"""Check if the target path is within the given directory."""
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
return os.path.commonpath([abs_directory]) == os.path.commonpath(
[abs_directory, abs_target]
)
def safe_extract(tar: tarfile.TarFile, path: str = ".") -> None:
"""Safely extract tar files to prevent path traversal attacks."""
for member in tar.getmembers():
member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path):
raise Exception(f"Attempted Path Traversal in Tar File: {member.name}")
tar.extractall(path=path)
def extract_archive(archive_path: str, output_dir: str) -> List[str]:
"""
Extracts a tar.gz archive to the specified output directory.
Args:
archive_path (str): The path to the tar.gz archive.
output_dir (str): The directory to extract files to.
Returns:
List[str]: A list of extracted file paths.
"""
logger.info(f"Extracting archive {archive_path} to {output_dir}")
try:
with tarfile.open(archive_path, "r:gz") as tar:
safe_extract(tar, path=output_dir)
extracted_files = tar.getnames()
return [os.path.join(output_dir, file) for file in extracted_files]
    except Exception as e:  # tarfile.TarError, OSError, or path-traversal errors
logger.error(f"Error extracting archive: {e}")
return []
def generate_actions(file_path: str, new_index: str):
"""
    Generator function that reads a bulk-format .ndjson file two lines at a
    time (action metadata line + document line) and yields bulk index actions.
Args:
file_path (str): Path to the .ndjson file.
new_index (str): The new index name to replace in actions.
Yields:
Dict[str, Any]: Action dictionaries for bulk upload.
"""
try:
with open(file_path, "r", encoding="utf-8") as f:
while True:
action_line = f.readline().strip()
document_line = f.readline().strip()
if not action_line or not document_line:
break
# Replace the index name in the action
action = json.loads(action_line)
document = json.loads(document_line)
yield {
"_op_type": "index",
"_index": new_index,
"_id": action["index"]["_id"],
"_source": document,
}
except IOError as e:
logger.error(f"Error reading file {file_path}: {e}")
def upload_data(es_client: Elasticsearch, index: str, file_path: str) -> None:
"""
Loads JSON data into the specified Elasticsearch index.
Args:
es_client (Elasticsearch): The Elasticsearch client.
index (str): The index to load data into.
file_path (str): Path to the data file.
"""
try:
response = helpers.bulk(
es_client, generate_actions(file_path, index), index=index, chunk_size=1000
)
logger.info(f"{response[0]} documents uploaded to index {index}")
    except Exception as e:  # ApiError, TransportError, helpers.BulkIndexError, etc.
logger.error(f"Error uploading data: {e}")
def create_input_index(es_client: Elasticsearch, index: str) -> None:
"""
Creates an Elasticsearch index for input data. If it exists, deletes and recreates it.
Args:
es_client (Elasticsearch): The Elasticsearch client.
index (str): The index to create.
"""
try:
if es_client.indices.exists(index=index):
logger.info(f"Deleting existing index {index}")
es_client.indices.delete(index=index)
logger.info(f"Creating index {index}")
es_client.indices.create(index=index)
except (ApiError, TransportError) as e:
logger.error(f"Error creating index {index}: {e}")
def create_job_config(
es_client: Elasticsearch,
job_config: Dict[str, Any],
new_index: Optional[str] = None,
) -> Optional[str]:
"""
Uploads the job configuration using the put_job API.
Args:
es_client (Elasticsearch): The Elasticsearch client.
job_config (Dict[str, Any]): The job configuration to upload.
new_index (Optional[str]): New index name for the datafeed config.
Returns:
Optional[str]: The job ID if successful, None otherwise.
"""
job_fields = [
"job_id",
"description",
"analysis_config",
"data_description",
"model_snapshot_retention_days",
"results_index_name",
"analysis_limits",
"custom_settings",
"allow_lazy_open",
"datafeed_config",
]
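    # Keep only fields accepted by the put_job API; exported configurations also
    # carry read-only fields (e.g. create_time, job_version) that would be rejected.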
filtered_config = {key: job_config[key] for key in job_fields if key in job_config}
job_id = filtered_config.get("job_id")
if not job_id:
logger.error("Job ID not found in job configuration.")
return None
# Remove sensitive or irrelevant fields
datafeed_config = filtered_config.get("datafeed_config", {})
datafeed_config.pop("authorization", None)
datafeed_config.pop("job_id", None)
if new_index:
datafeed_config["indices"] = [new_index]
filtered_config["datafeed_config"] = datafeed_config
# Check if the job already exists and delete it
try:
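        # ml.get_jobs raises a 404 ApiError when the job is unknown,
        # which is handled by the except branch below.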
if es_client.ml.get_jobs(job_id=job_id):
es_client.ml.delete_job(job_id=job_id, force=True)
logger.info(f"Deleted existing job with ID: {job_id}")
except (ApiError, TransportError):
logger.info(f"Job with ID {job_id} does not exist, proceeding to create.")
try:
response = es_client.ml.put_job(job_id=job_id, body=filtered_config)
logger.info(f"Job configuration uploaded with ID: {response['job_id']}")
return job_id
except (ApiError, TransportError) as e:
logger.error(f"Error uploading job configuration: {e}")
return None
def load_snapshot_stats(es_client: Elasticsearch, file_path: str) -> Optional[str]:
"""
Loads snapshot statistics from the given file and indexes it in Elasticsearch.
Args:
es_client (Elasticsearch): The Elasticsearch client.
file_path (str): The path to the snapshot statistics file.
Returns:
Optional[str]: The snapshot ID if successful, None otherwise.
"""
logger.info(f"Loading snapshot statistics from {file_path}")
try:
with open(file_path, "r", encoding="utf-8") as f:
snapshot_stats = json.load(f)
snapshot_id = snapshot_stats.get("snapshot_id")
if not snapshot_id:
logger.error("Snapshot ID not found in snapshot statistics.")
return None
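            # Model snapshot documents are stored in the ML results index;
            # .ml-anomalies-shared is the default shared results index.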
index_name = ".ml-anomalies-shared"
# Extract ID from file name
id_ = os.path.splitext(os.path.basename(file_path))[0].replace(
"ml-anomalies-snapshot_doc_", ""
)
response = es_client.index(index=index_name, body=snapshot_stats, id=id_)
logger.info(f"Snapshot statistics indexed with ID: {response['_id']}")
return snapshot_id
except (ApiError, TransportError, IOError, json.JSONDecodeError) as e:
logger.error(f"Error indexing snapshot statistics: {e}")
return None
def find_file(file_name: str, extracted_files: List[str]) -> Optional[str]:
"""Find a file in the extracted files list."""
for file in extracted_files:
if file_name in file:
return file
return None
def import_model_state(
job_id: str, es_client: Elasticsearch, archive_path: str
) -> None:
"""
Imports the model state, job configuration, annotations, and input data from an archive to Elasticsearch.
Args:
job_id (str): The ID of the job to import.
es_client (Elasticsearch): The Elasticsearch client.
archive_path (str): The path to the tar.gz archive.
"""
safe_job_id = sanitize_filename(job_id)
output_dir = f"extracted_{safe_job_id}"
extracted_files = extract_archive(archive_path, output_dir)
if not extracted_files:
logger.error("No files extracted. Aborting import.")
return
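    # The archive is expected to contain (names derived from the job ID):
    #   <job_id>_config.json             - job and datafeed configuration
    #   <job_id>_snapshot_docs.ndjson    - model state documents (bulk format)
    #   <job_id>_input.ndjson            - optional input data (bulk format)
    #   ml-anomalies-snapshot_doc_*      - model snapshot statistics document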
input_index = None
job_config_file = f"{safe_job_id}_config.json"
snapshot_docs_file = f"{safe_job_id}_snapshot_docs.ndjson"
input_file = f"{safe_job_id}_input.ndjson"
snapshot_stats_file = None
# Find snapshot statistics file
for file in extracted_files:
if "ml-anomalies-snapshot_doc" in file:
snapshot_stats_file = file
break
# Load the input data if available
file = find_file(input_file, extracted_files)
if file:
logger.info(f"Importing input data from {file}")
input_index = f"{safe_job_id}-input"
create_input_index(es_client, input_index)
upload_data(es_client, input_index, file)
else:
logger.warning(f"Input data file {input_file} not found in the archive.")
# Load the job configuration
file = find_file(job_config_file, extracted_files)
if file:
logger.info(f"Importing job configuration from {file}")
try:
with open(file, "r", encoding="utf-8") as f:
job_config = json.load(f)
job_id = create_job_config(es_client, job_config, input_index)
if not job_id:
logger.error("Failed to create job configuration.")
return
except (IOError, json.JSONDecodeError) as e:
logger.error(f"Error reading job configuration: {e}")
return
else:
logger.error(
f"Job configuration file {job_config_file} not found in the archive."
)
return
# Load the job snapshot docs
file = find_file(snapshot_docs_file, extracted_files)
if file:
logger.info(f"Importing snapshots from {file}")
upload_data(es_client, ".ml-state-write", file)
else:
logger.error(
f"Snapshot documents file {snapshot_docs_file} not found in the archive."
)
return
# Load the snapshot stats
if snapshot_stats_file:
logger.info(f"Importing snapshot statistics from {snapshot_stats_file}")
snapshot_id = load_snapshot_stats(es_client, snapshot_stats_file)
if not snapshot_id:
logger.error("Failed to load snapshot statistics.")
return
else:
logger.error("Snapshot statistics file not found in the archive.")
return
# Revert the job to the snapshot
time.sleep(2) # Wait for the snapshot to be indexed
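    # Reverting makes the imported snapshot the job's active model state,
    # so the job can be opened from it afterwards.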
try:
es_client.ml.revert_model_snapshot(job_id=job_id, snapshot_id=snapshot_id)
logger.info(f"Reverted job {job_id} to snapshot {snapshot_id}")
except (ApiError, TransportError) as e:
logger.error(f"Error reverting job to snapshot: {e}")
finally:
# Clean up extracted files and directory
try:
for file in extracted_files:
os.remove(file)
os.rmdir(output_dir)
except OSError as e:
logger.warning(f"Error cleaning up extracted files: {e}")
logger.info(f"Import of job {job_id} completed successfully.")
def main() -> None:
"""Main function to import model state to Elasticsearch."""
parser = argparse.ArgumentParser(description="Import model state to Elasticsearch.")
parser.add_argument(
"--url", type=str, default="https://localhost:9200", help="Elasticsearch URL"
)
parser.add_argument(
"--username",
type=str,
required=True,
help="Username for Elasticsearch authentication",
)
parser.add_argument(
"--password", type=str, help="Password for Elasticsearch authentication"
)
parser.add_argument("--job_id", type=str, required=True, help="Job ID to import")
parser.add_argument(
"--archive_path", type=str, required=True, help="Path to the archive file"
)
parser.add_argument("--cloud_id", type=str, help="Cloud ID for Elasticsearch")
parser.add_argument(
"--ignore_certs",
action="store_true",
help="Ignore SSL certificate verification",
)
args = parser.parse_args()
# Handle password securely
if not args.password:
args.password = getpass(prompt="Enter Elasticsearch password: ")
# Validate archive_path
if not os.path.isfile(args.archive_path):
logger.error(f"Archive file {args.archive_path} does not exist.")
return
logger.info("Connecting to Elasticsearch")
try:
if args.cloud_id:
logger.info("Connecting to Elasticsearch cloud using cloud_id")
es_client = Elasticsearch(
cloud_id=args.cloud_id,
basic_auth=(args.username, args.password),
verify_certs=(not args.ignore_certs),
)
else:
logger.info("Connecting to Elasticsearch using URL")
es_client = Elasticsearch(
[args.url],
basic_auth=(args.username, args.password),
verify_certs=(not args.ignore_certs),
)
    except (ValueError, ApiError, TransportError) as e:
logger.error(f"Failed to connect to Elasticsearch: {e}")
return
import_model_state(args.job_id, es_client, args.archive_path)
if __name__ == "__main__":
main()