def main()

in scripts/share_ad_job_state/export_model_snapshot.py
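
For context, the module-level imports and helpers this function relies on live elsewhere in the script. A minimal sketch of the assumed preamble (elasticsearch-py 8.x client; validate_date assumed to be an argparse type callable that returns the validated string; the save_*/create_archive helpers are defined later in the file):

import argparse
import logging
from datetime import datetime
from getpass import getpass

from elasticsearch import ApiError, Elasticsearch, TransportError

logger = logging.getLogger(__name__)


def validate_date(value: str) -> str:
    # Sketch: argparse type callable assumed to accept YYYY-MM-DDTHH:MM:SS
    # strings and reject anything else with an ArgumentTypeError.
    try:
        datetime.strptime(value, "%Y-%m-%dT%H:%M:%S")
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"Invalid date {value!r}; expected YYYY-MM-DDTHH:MM:SS"
        )
    return value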


def main() -> None:
    """
    Extract an anomaly detection job's model state from Elasticsearch and
    bundle it into a <job_id>_state.tar.gz archive.

    Example usage:
    python export_model_snapshot.py --url https://localhost:9200 --username user --job_id <job_id> --before_date 2023-05-10T00:00:00 --after_date 2023-01-01T00:00:00 --include_inputs

    Output:
    <job_id>_state.tar.gz archive with the following files:
    - ml-anomalies-snapshot_doc_<id>.json snapshot stats document
    - <job_id>_config.json job configuration
    - <job_id>_snapshot_docs.ndjson snapshot documents
    - <job_id>_annotations.ndjson annotations
    - <job_id>_input.ndjson input data (if --include_inputs flag is set)
    """
    parser = argparse.ArgumentParser(
        description=(
            "Extract model state from Elasticsearch. WARNING: This operation will extract data that may include PII."
        )
    )
    parser.add_argument(
        "--url",
        type=str,
        required=False,
        default="https://localhost:9200",
        help="Elasticsearch URL",
    )
    parser.add_argument(
        "--username",
        type=str,
        required=True,
        help="Username for Elasticsearch authentication",
    )
    parser.add_argument(
        "--password",
        type=str,
        required=False,
        help="Password for Elasticsearch authentication",
    )
    parser.add_argument(
        "--job_id", type=str, required=True, help="Job ID to extract model state"
    )
    parser.add_argument(
        "--cloud_id", type=str, required=False, help="Cloud ID for Elasticsearch"
    )
    parser.add_argument(
        "--before_date",
        type=validate_date,
        required=False,
        help="Search for the latest snapshot CREATED before the given date (format: YYYY-MM-DDTHH:MM:SS)",
    )
    parser.add_argument(
        "--after_date",
        type=validate_date,
        required=False,
        help="Search for input data and annotations after the specified date (format: YYYY-MM-DDTHH:MM:SS)",
    )
    parser.add_argument(
        "--include_inputs",
        action="store_true",
        help="Include input data in the archive",
    )
    parser.add_argument(
        "--ignore_certs",
        action="store_true",
        help="Disable SSL certificate verification",
    )
    args = parser.parse_args()

    # Prompt for the password if it was not supplied, so it never appears in shell history
    if not args.password:
        args.password = getpass(prompt="Enter Elasticsearch password: ")

    # Warn about PII data
    logger.warning("This operation will extract data that may include PII.")
    confirm = input("Do you wish to continue? (yes/no): ")
    if confirm.strip().lower() != "yes":
        logger.info("Operation aborted by the user.")
        return

    logger.info("Connecting to Elasticsearch")
    # Connect to an Elasticsearch instance
    try:
        if args.cloud_id:
            logger.info("Connecting to Elasticsearch cloud using cloud_id")
            es_client = Elasticsearch(
                cloud_id=args.cloud_id,
                basic_auth=(args.username, args.password),
                verify_certs=(not args.ignore_certs),
            )
        else:
            logger.info("Connecting to Elasticsearch using URL")
            es_client = Elasticsearch(
                [args.url],
                basic_auth=(args.username, args.password),
                verify_certs=(not args.ignore_certs),
            )
    except (ApiError, TransportError) as e:
        logger.error(f"Failed to connect to Elasticsearch: {e}")
        return
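    # Note: the elasticsearch-py client connects lazily, so most connection
    # and auth failures surface on the first request rather than at
    # construction; a preflight es_client.info() call would fail fast here.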

    snapshot_info = get_snapshot_info(es_client, args.job_id, args.before_date)
    if snapshot_info is None:
        logger.error("Failed to retrieve snapshot info.")
        return
    snapshot_id, snapshot_doc_count = snapshot_info
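    # snapshot_doc_count is passed to save_snapshots() below, presumably so it
    # can page through and verify the expected number of state documents.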

    # Save the snapshot stats document for the chosen snapshot
    file_name_ml_anomalies = save_snapshot_stats(args.job_id, snapshot_id, es_client)

    # Get the job configuration and store it
    job_config_result = save_job_config(args.job_id, es_client)
    if job_config_result:
        file_name_job_config, job_configuration = job_config_result
    else:
        logger.error("Failed to retrieve job configuration.")
        return

    # Get the annotations and store them
    file_name_annotations = save_annotations(
        args.job_id, args.before_date, args.after_date, es_client
    )

    # Get the input data and store it
    if args.include_inputs:
        file_name_inputs = (
            save_inputs(job_configuration, args.before_date, args.after_date, es_client)
            if job_configuration
            else None
        )
    else:
        logger.info("Input data will not be included in the archive.")
        file_name_inputs = None

    # Save the model snapshot documents
    file_name_snapshots = save_snapshots(
        args.job_id, snapshot_id, es_client, snapshot_doc_count
    )

    # Create an archive with all generated files
    files_to_archive = [
        file_name_ml_anomalies,
        file_name_job_config,
        file_name_snapshots,
        file_name_annotations,
        file_name_inputs,
    ]
    create_archive(args.job_id, files_to_archive)
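
Note that files_to_archive can contain None (for example when --include_inputs is not set), so create_archive is expected to skip empty entries. The helper is defined elsewhere in the script; a minimal sketch of a compatible implementation, assuming tarfile with gzip compression (hypothetical, the real helper may differ):

import tarfile


def create_archive(job_id: str, files_to_archive: list) -> None:
    # Sketch: bundle the generated files into <job_id>_state.tar.gz,
    # skipping entries that are None because a step was skipped.
    archive_name = f"{job_id}_state.tar.gz"
    with tarfile.open(archive_name, "w:gz") as tar:
        for file_name in files_to_archive:
            if file_name:
                tar.add(file_name)
    logger.info(f"Created archive {archive_name}")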