scripts/share_ad_job_state/import_model_snapshot.py

# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.

import argparse
import json
import os
import re
import tarfile
import time
from getpass import getpass
from typing import Any, Dict, List, Optional

# Disable noisy warning about missing certificate verification
import urllib3
from elasticsearch import ApiError, Elasticsearch, TransportError, helpers
from loguru import logger

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


def sanitize_filename(name: str) -> str:
    """Sanitize filenames to prevent directory traversal and other security issues."""
    return re.sub(r"[^a-zA-Z0-9_-]", "_", name)


def is_within_directory(directory: str, target: str) -> bool:
    """Check if the target path is within the given directory."""
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)
    return os.path.commonpath([abs_directory]) == os.path.commonpath(
        [abs_directory, abs_target]
    )


def safe_extract(tar: tarfile.TarFile, path: str = ".") -> None:
    """Safely extract tar files to prevent path traversal attacks."""
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if not is_within_directory(path, member_path):
            raise Exception(f"Attempted Path Traversal in Tar File: {member.name}")
    tar.extractall(path=path)


def extract_archive(archive_path: str, output_dir: str) -> List[str]:
    """
    Extracts a tar.gz archive to the specified output directory.

    Args:
        archive_path (str): The path to the tar.gz archive.
        output_dir (str): The directory to extract files to.

    Returns:
        List[str]: A list of extracted file paths.
    """
    logger.info(f"Extracting archive {archive_path} to {output_dir}")
    try:
        with tarfile.open(archive_path, "r:gz") as tar:
            safe_extract(tar, path=output_dir)
            extracted_files = tar.getnames()
            return [os.path.join(output_dir, file) for file in extracted_files]
    except (tarfile.TarError, Exception) as e:
        logger.error(f"Error extracting archive: {e}")
        return []


def generate_actions(file_path: str, new_index: str):
    """
    Generator function that reads a .ndjson file and yields actions.
    It yields two lines at a time (action + document).

    Args:
        file_path (str): Path to the .ndjson file.
        new_index (str): The new index name to replace in actions.

    Yields:
        Dict[str, Any]: Action dictionaries for bulk upload.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            while True:
                action_line = f.readline().strip()
                document_line = f.readline().strip()
                if not action_line or not document_line:
                    break
                # Replace the index name in the action
                action = json.loads(action_line)
                document = json.loads(document_line)
                yield {
                    "_op_type": "index",
                    "_index": new_index,
                    "_id": action["index"]["_id"],
                    "_source": document,
                }
    except IOError as e:
        logger.error(f"Error reading file {file_path}: {e}")


def upload_data(es_client: Elasticsearch, index: str, file_path: str) -> None:
    """
    Loads JSON data into the specified Elasticsearch index.

    Args:
        es_client (Elasticsearch): The Elasticsearch client.
        index (str): The index to load data into.
        file_path (str): Path to the data file.
    """
    try:
        response = helpers.bulk(
            es_client, generate_actions(file_path, index), index=index, chunk_size=1000
        )
        logger.info(f"{response[0]} documents uploaded to index {index}")
    except (ApiError, TransportError, Exception) as e:
        logger.error(f"Error uploading data: {e}")


def create_input_index(es_client: Elasticsearch, index: str) -> None:
    """
    Creates an Elasticsearch index for input data.
    If it exists, deletes and recreates it.

    Args:
        es_client (Elasticsearch): The Elasticsearch client.
        index (str): The index to create.
    """
    try:
        if es_client.indices.exists(index=index):
            logger.info(f"Deleting existing index {index}")
            es_client.indices.delete(index=index)
        logger.info(f"Creating index {index}")
        es_client.indices.create(index=index)
    except (ApiError, TransportError) as e:
        logger.error(f"Error creating index {index}: {e}")


def create_job_config(
    es_client: Elasticsearch,
    job_config: Dict[str, Any],
    new_index: Optional[str] = None,
) -> Optional[str]:
    """
    Uploads the job configuration using the put_job API.

    Args:
        es_client (Elasticsearch): The Elasticsearch client.
        job_config (Dict[str, Any]): The job configuration to upload.
        new_index (Optional[str]): New index name for the datafeed config.

    Returns:
        Optional[str]: The job ID if successful, None otherwise.
    """
    job_fields = [
        "job_id",
        "description",
        "analysis_config",
        "data_description",
        "model_snapshot_retention_days",
        "results_index_name",
        "analysis_limits",
        "custom_settings",
        "allow_lazy_open",
        "datafeed_config",
    ]
    filtered_config = {key: job_config[key] for key in job_fields if key in job_config}
    job_id = filtered_config.get("job_id")
    if not job_id:
        logger.error("Job ID not found in job configuration.")
        return None

    # Remove sensitive or irrelevant fields
    datafeed_config = filtered_config.get("datafeed_config", {})
    datafeed_config.pop("authorization", None)
    datafeed_config.pop("job_id", None)
    if new_index:
        datafeed_config["indices"] = [new_index]
    filtered_config["datafeed_config"] = datafeed_config

    # Check if the job already exists and delete it
    try:
        if es_client.ml.get_jobs(job_id=job_id):
            es_client.ml.delete_job(job_id=job_id, force=True)
            logger.info(f"Deleted existing job with ID: {job_id}")
    except (ApiError, TransportError):
        logger.info(f"Job with ID {job_id} does not exist, proceeding to create.")

    try:
        response = es_client.ml.put_job(job_id=job_id, body=filtered_config)
        logger.info(f"Job configuration uploaded with ID: {response['job_id']}")
        return job_id
    except (ApiError, TransportError) as e:
        logger.error(f"Error uploading job configuration: {e}")
        return None


def load_snapshot_stats(es_client: Elasticsearch, file_path: str) -> Optional[str]:
    """
    Loads snapshot statistics from the given file and indexes it in Elasticsearch.

    Args:
        es_client (Elasticsearch): The Elasticsearch client.
        file_path (str): The path to the snapshot statistics file.

    Returns:
        Optional[str]: The snapshot ID if successful, None otherwise.
    """
    logger.info(f"Loading snapshot statistics from {file_path}")
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            snapshot_stats = json.load(f)
        snapshot_id = snapshot_stats.get("snapshot_id")
        if not snapshot_id:
            logger.error("Snapshot ID not found in snapshot statistics.")
            return None
        index_name = ".ml-anomalies-shared"
        # Extract ID from file name
        id_ = os.path.splitext(os.path.basename(file_path))[0].replace(
            "ml-anomalies-snapshot_doc_", ""
        )
        response = es_client.index(index=index_name, body=snapshot_stats, id=id_)
        logger.info(f"Snapshot statistics indexed with ID: {response['_id']}")
        return snapshot_id
    except (ApiError, TransportError, IOError, json.JSONDecodeError) as e:
        logger.error(f"Error indexing snapshot statistics: {e}")
        return None


def find_file(file_name: str, extracted_files: List[str]) -> Optional[str]:
    """Find a file in the extracted files list."""
    for file in extracted_files:
        if file_name in file:
            return file
    return None


def import_model_state(
    job_id: str, es_client: Elasticsearch, archive_path: str
) -> None:
    """
    Imports the model state, job configuration, annotations, and input data
    from an archive to Elasticsearch.

    Args:
        job_id (str): The ID of the job to import.
        es_client (Elasticsearch): The Elasticsearch client.
        archive_path (str): The path to the tar.gz archive.
    """
    safe_job_id = sanitize_filename(job_id)
    output_dir = f"extracted_{safe_job_id}"
    extracted_files = extract_archive(archive_path, output_dir)
    if not extracted_files:
        logger.error("No files extracted. Aborting import.")
        return

    input_index = None
    job_config_file = f"{safe_job_id}_config.json"
    snapshot_docs_file = f"{safe_job_id}_snapshot_docs.ndjson"
    input_file = f"{safe_job_id}_input.ndjson"
    snapshot_stats_file = None

    # Find snapshot statistics file
    for file in extracted_files:
        if "ml-anomalies-snapshot_doc" in file:
            snapshot_stats_file = file
            break

    # Load the input data if available
    file = find_file(input_file, extracted_files)
    if file:
        logger.info(f"Importing input data from {file}")
        input_index = f"{safe_job_id}-input"
        create_input_index(es_client, input_index)
        upload_data(es_client, input_index, file)
    else:
        logger.warning(f"Input data file {input_file} not found in the archive.")

    # Load the job configuration
    file = find_file(job_config_file, extracted_files)
    if file:
        logger.info(f"Importing job configuration from {file}")
        try:
            with open(file, "r", encoding="utf-8") as f:
                job_config = json.load(f)
            job_id = create_job_config(es_client, job_config, input_index)
            if not job_id:
                logger.error("Failed to create job configuration.")
                return
        except (IOError, json.JSONDecodeError) as e:
            logger.error(f"Error reading job configuration: {e}")
            return
    else:
        logger.error(
            f"Job configuration file {job_config_file} not found in the archive."
        )
        return

    # Load the job snapshot docs
    file = find_file(snapshot_docs_file, extracted_files)
    if file:
        logger.info(f"Importing snapshots from {file}")
        upload_data(es_client, ".ml-state-write", file)
    else:
        logger.error(
            f"Snapshot documents file {snapshot_docs_file} not found in the archive."
        )
        return

    # Load the snapshot stats
    if snapshot_stats_file:
        logger.info(f"Importing snapshot statistics from {snapshot_stats_file}")
        snapshot_id = load_snapshot_stats(es_client, snapshot_stats_file)
        if not snapshot_id:
            logger.error("Failed to load snapshot statistics.")
            return
    else:
        logger.error("Snapshot statistics file not found in the archive.")
        return

    # Revert the job to the snapshot
    time.sleep(2)  # Wait for the snapshot to be indexed
    try:
        es_client.ml.revert_model_snapshot(job_id=job_id, snapshot_id=snapshot_id)
        logger.info(f"Reverted job {job_id} to snapshot {snapshot_id}")
    except (ApiError, TransportError) as e:
        logger.error(f"Error reverting job to snapshot: {e}")
    finally:
        # Clean up extracted files and directory
        try:
            for file in extracted_files:
                os.remove(file)
            os.rmdir(output_dir)
        except OSError as e:
            logger.warning(f"Error cleaning up extracted files: {e}")

    logger.info(f"Import of job {job_id} completed successfully.")


def main() -> None:
    """Main function to import model state to Elasticsearch."""
    parser = argparse.ArgumentParser(description="Import model state to Elasticsearch.")
    parser.add_argument(
        "--url", type=str, default="https://localhost:9200", help="Elasticsearch URL"
    )
    parser.add_argument(
        "--username",
        type=str,
        required=True,
        help="Username for Elasticsearch authentication",
    )
    parser.add_argument(
        "--password", type=str, help="Password for Elasticsearch authentication"
    )
    parser.add_argument("--job_id", type=str, required=True, help="Job ID to import")
    parser.add_argument(
        "--archive_path", type=str, required=True, help="Path to the archive file"
    )
    parser.add_argument("--cloud_id", type=str, help="Cloud ID for Elasticsearch")
    parser.add_argument(
        "--ignore_certs",
        action="store_true",
        help="Ignore SSL certificate verification",
    )
    args = parser.parse_args()

    # Handle password securely
    if not args.password:
        args.password = getpass(prompt="Enter Elasticsearch password: ")

    # Validate archive_path
    if not os.path.isfile(args.archive_path):
        logger.error(f"Archive file {args.archive_path} does not exist.")
        return

    logger.info("Connecting to Elasticsearch")
    try:
        if args.cloud_id:
            logger.info("Connecting to Elasticsearch cloud using cloud_id")
            es_client = Elasticsearch(
                cloud_id=args.cloud_id,
                basic_auth=(args.username, args.password),
                verify_certs=(not args.ignore_certs),
            )
        else:
            logger.info("Connecting to Elasticsearch using URL")
            es_client = Elasticsearch(
                [args.url],
                basic_auth=(args.username, args.password),
                verify_certs=(not args.ignore_certs),
            )
    except (ApiError, TransportError) as e:
        logger.error(f"Failed to connect to Elasticsearch: {e}")
        return

    import_model_state(args.job_id, es_client, args.archive_path)


if __name__ == "__main__":
    main()
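
# Example invocation (a sketch based on the arguments defined in main(); the job ID,
# archive file name, and username below are placeholders, not values shipped with
# this script):
#
#   python import_model_snapshot.py \
#       --url https://localhost:9200 \
#       --username elastic \
#       --job_id my-anomaly-job \
#       --archive_path my-anomaly-job.tar.gz \
#       --ignore_certs
#
# If --password is omitted, the script prompts for it interactively via getpass.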