# scripts/share_ad_job_state/export_model_snapshot.py
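
# Imports assumed from the top of the full module (this excerpt shows only
# main()); ApiError and TransportError are the elasticsearch-py 8.x exception
# types caught below, and the module-level logger is a presumed convention.
import argparse
import logging
from getpass import getpass

from elasticsearch import ApiError, Elasticsearch, TransportError

logger = logging.getLogger(__name__)
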
def main() -> None:
"""
Main function to extract model state from Elasticsearch.
Example usage:
python export_model_snapshot.py --url https://localhost:9200 --username user --job_id <job_id> --before_date 2023-05-10T00:00:00 --after_date 2023-01-01T00:00:00 --include_inputs
Output:
<job_id>_state.tar.gz archive with the following files:
- ml-anomalies-snapshot_doc_<id>.json snapshot stats document
- <job_id>_config.json job configuration
- <job_id>_snapshot_docs.ndjson snapshot documents
- <job_id>_annotations.ndjson annotations
- <job_id>_input.ndjson input data (if --include_inputs flag is set)
"""
    parser = argparse.ArgumentParser(
        description=(
            "Extract model state from Elasticsearch. WARNING: This operation will extract data that may include PII."
        )
    )
    parser.add_argument(
        "--url",
        type=str,
        required=False,
        default="https://localhost:9200",
        help="Elasticsearch URL",
    )
    parser.add_argument(
        "--username",
        type=str,
        required=True,
        help="Username for Elasticsearch authentication",
    )
    parser.add_argument(
        "--password",
        type=str,
        required=False,
        help="Password for Elasticsearch authentication (prompted for if omitted)",
    )
    parser.add_argument(
        "--job_id",
        type=str,
        required=True,
        help="ID of the job whose model state should be extracted",
    )
    parser.add_argument(
        "--cloud_id", type=str, required=False, help="Cloud ID for Elasticsearch"
    )
    parser.add_argument(
        "--before_date",
        type=validate_date,
        required=False,
        help="Search for the latest snapshot CREATED before the given date (format: YYYY-MM-DDTHH:MM:SS)",
    )
    parser.add_argument(
        "--after_date",
        type=validate_date,
        required=False,
        help="Search for input data and annotations after the specified date (format: YYYY-MM-DDTHH:MM:SS)",
    )
    parser.add_argument(
        "--include_inputs",
        action="store_true",
        help="Include input data in the archive",
    )
    parser.add_argument(
        "--ignore_certs",
        action="store_true",
        help="Disable SSL certificate verification",
    )
    args = parser.parse_args()
    # Prompt for the password interactively if it was not supplied on the command line
    if not args.password:
        args.password = getpass(prompt="Enter Elasticsearch password: ")

    # Warn about PII and require explicit confirmation before exporting anything
    logger.warning("This operation will extract data that may include PII.")
    confirm = input("Do you wish to continue? (yes/no): ")
    if confirm.strip().lower() != "yes":
        logger.info("Operation aborted by the user.")
        return
    logger.info("Connecting to Elasticsearch")
    # Connect to an Elasticsearch instance
    try:
        if args.cloud_id:
            logger.info("Connecting to Elasticsearch cloud using cloud_id")
            es_client = Elasticsearch(
                cloud_id=args.cloud_id,
                basic_auth=(args.username, args.password),
                verify_certs=(not args.ignore_certs),
            )
        else:
            logger.info("Connecting to Elasticsearch using URL")
            es_client = Elasticsearch(
                [args.url],
                basic_auth=(args.username, args.password),
                verify_certs=(not args.ignore_certs),
            )
    except (ApiError, TransportError) as e:
        logger.error(f"Failed to connect to Elasticsearch: {e}")
        return
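
    # get_snapshot_info (defined elsewhere in this module) is expected to
    # return a (snapshot_id, snapshot_doc_count) tuple for the latest matching
    # model snapshot, or None when no suitable snapshot exists.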
    snapshot_info = get_snapshot_info(es_client, args.job_id, args.before_date)
    if snapshot_info is None:
        logger.error("Failed to retrieve snapshot info.")
        return
    snapshot_id, snapshot_doc_count = snapshot_info

    # Save the snapshot stats document
    file_name_ml_anomalies = save_snapshot_stats(args.job_id, snapshot_id, es_client)

    # Get the job configuration and store it
    job_config_result = save_job_config(args.job_id, es_client)
    if job_config_result:
        file_name_job_config, job_configuration = job_config_result
    else:
        logger.error("Failed to retrieve job configuration.")
        return

    # Get the annotations and store them
    file_name_annotations = save_annotations(
        args.job_id, args.before_date, args.after_date, es_client
    )

    # Get the input data and store it
    if args.include_inputs:
        file_name_inputs = (
            save_inputs(job_configuration, args.before_date, args.after_date, es_client)
            if job_configuration
            else None
        )
    else:
        logger.info("Input data will not be included in the archive.")
        file_name_inputs = None

    # Save the model state snapshot documents
    filename_snapshots = save_snapshots(
        args.job_id, snapshot_id, es_client, snapshot_doc_count
    )

    # Create an archive with all generated files; file_name_inputs may be
    # None, so create_archive is expected to skip missing entries
    files_to_archive = [
        file_name_ml_anomalies,
        file_name_job_config,
        filename_snapshots,
        file_name_annotations,
        file_name_inputs,
    ]
    create_archive(args.job_id, files_to_archive)
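

# validate_date and the get_/save_/create_ helpers used above are defined
# elsewhere in this module. A minimal sketch of validate_date, assuming it
# parses the YYYY-MM-DDTHH:MM:SS format named in the --before_date/--after_date
# help text and rejects anything else with an argparse-friendly error:
from datetime import datetime


def validate_date(date_string: str) -> datetime:
    """Parse a CLI date argument into a datetime, or fail argument parsing."""
    try:
        return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%S")
    except ValueError as e:
        raise argparse.ArgumentTypeError(
            f"Invalid date: {date_string!r} (expected YYYY-MM-DDTHH:MM:SS)"
        ) from e


# Presumably the script is run directly, as the docstring's usage example
# shows, via the standard entry-point guard:
if __name__ == "__main__":
    main()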