def main()

in scripts/share_ad_job_state/import_model_snapshot.py [0:0]


def main() -> None:
    """Import a job's model state from an archive into Elasticsearch.

    Command-line arguments:
        --url: Elasticsearch URL (default ``https://localhost:9200``). Ignored
            when ``--cloud_id`` is given.
        --username: user name for basic authentication (required).
        --password: password for basic authentication. Prompted for
            interactively when omitted.
        --job_id: ID of the job whose model state is imported (required).
        --archive_path: path of the archive file to import (required).
        --cloud_id: Elastic Cloud ID. Takes precedence over ``--url``.
        --ignore_certs: disable TLS certificate verification.

    Validation and client-setup errors are logged and the function returns
    early instead of raising. The actual import is delegated to
    ``import_model_state``.
    """
    parser = argparse.ArgumentParser(description="Import model state to Elasticsearch.")
    parser.add_argument(
        "--url", type=str, default="https://localhost:9200", help="Elasticsearch URL"
    )
    parser.add_argument(
        "--username",
        type=str,
        required=True,
        help="Username for Elasticsearch authentication",
    )
    parser.add_argument(
        "--password", type=str, help="Password for Elasticsearch authentication"
    )
    parser.add_argument("--job_id", type=str, required=True, help="Job ID to import")
    parser.add_argument(
        "--archive_path", type=str, required=True, help="Path to the archive file"
    )
    parser.add_argument("--cloud_id", type=str, help="Cloud ID for Elasticsearch")
    parser.add_argument(
        "--ignore_certs",
        action="store_true",
        help="Ignore SSL certificate verification",
    )
    args = parser.parse_args()

    # Validate the archive before prompting for credentials, so the user is not
    # asked for a password on a run that is bound to fail anyway.
    if not os.path.isfile(args.archive_path):
        logger.error("Archive file %s does not exist.", args.archive_path)
        return

    # Prompt interactively instead of requiring the password on the command
    # line, where it would end up in shell history and process listings.
    if not args.password:
        args.password = getpass(prompt="Enter Elasticsearch password: ")

    # Options shared by both connection modes. --cloud_id takes precedence over --url.
    client_options = {
        "basic_auth": (args.username, args.password),
        "verify_certs": not args.ignore_certs,
    }

    logger.info("Connecting to Elasticsearch")
    try:
        if args.cloud_id:
            logger.info("Connecting to Elasticsearch cloud using cloud_id")
            es_client = Elasticsearch(cloud_id=args.cloud_id, **client_options)
        else:
            logger.info("Connecting to Elasticsearch using URL")
            es_client = Elasticsearch([args.url], **client_options)
    except (ApiError, TransportError, ValueError) as e:
        # NOTE(review): the elasticsearch-py client connects lazily, so bad
        # hosts/credentials typically surface on the first request inside
        # import_model_state. Configuration errors (e.g. a malformed cloud_id)
        # are raised here as ValueError, hence it is caught alongside the
        # transport errors.
        logger.error("Failed to connect to Elasticsearch: %s", e)
        return

    import_model_state(args.job_id, es_client, args.archive_path)