in scripts/share_ad_job_state/import_model_snapshot.py [0:0]
def main() -> None:
    """Import a job's model state archive into Elasticsearch.

    Parses the command line, prompts for the password when ``--password`` is
    not supplied, checks that ``--archive_path`` is an existing file, builds an
    Elasticsearch client (from ``--cloud_id`` when given, otherwise from
    ``--url``) and passes it to ``import_model_state``.

    A missing archive file or a failure while constructing the client is
    logged and causes an early return. The client is always closed before
    this function returns, including when ``import_model_state`` raises.
    """
    parser = argparse.ArgumentParser(description="Import model state to Elasticsearch.")
    parser.add_argument(
        "--url", type=str, default="https://localhost:9200", help="Elasticsearch URL"
    )
    parser.add_argument(
        "--username",
        type=str,
        required=True,
        help="Username for Elasticsearch authentication",
    )
    parser.add_argument(
        "--password", type=str, help="Password for Elasticsearch authentication"
    )
    parser.add_argument("--job_id", type=str, required=True, help="Job ID to import")
    parser.add_argument(
        "--archive_path", type=str, required=True, help="Path to the archive file"
    )
    parser.add_argument("--cloud_id", type=str, help="Cloud ID for Elasticsearch")
    parser.add_argument(
        "--ignore_certs",
        action="store_true",
        help="Ignore SSL certificate verification",
    )
    args = parser.parse_args()

    # Prompt interactively when no password was given, so it does not have to
    # appear on the command line (process list, shell history).
    if not args.password:
        args.password = getpass(prompt="Enter Elasticsearch password: ")

    # Fail fast before touching the cluster if there is nothing to import.
    if not os.path.isfile(args.archive_path):
        logger.error("Archive file %s does not exist.", args.archive_path)
        return

    # Authentication and TLS options are identical for both connection modes;
    # only the endpoint specification differs.
    client_options = {
        "basic_auth": (args.username, args.password),
        "verify_certs": not args.ignore_certs,
    }

    logger.info("Connecting to Elasticsearch")
    try:
        # --cloud_id takes precedence over --url when both are supplied.
        if args.cloud_id:
            logger.info("Connecting to Elasticsearch cloud using cloud_id")
            es_client = Elasticsearch(cloud_id=args.cloud_id, **client_options)
        else:
            logger.info("Connecting to Elasticsearch using URL")
            es_client = Elasticsearch([args.url], **client_options)
    except (ApiError, TransportError) as e:
        # NOTE(review): constructing the client does not contact the cluster,
        # so network/auth failures typically surface later, from
        # import_model_state, rather than here.
        logger.error("Failed to connect to Elasticsearch: %s", e)
        return

    try:
        import_model_state(args.job_id, es_client, args.archive_path)
    finally:
        # Release the client's pooled connections even if the import fails.
        es_client.close()