in assets/large_language_models/rag/components/src/validate_deployments.py [0:0]
def validate_aoai_deployments(parser_args, check_completion, check_embeddings, activity_logger: Logger):
    """Poll or create deployments in AOAI.

    Builds OpenAI parameter dicts for the completion and embedding models from
    their AzureML workspace connections (connection IDs come from environment
    variables), polls on the corresponding AOAI deployments, and writes a dummy
    JSON output file so downstream pipeline steps can order on this component.

    Args:
        parser_args: Parsed CLI args; reads `llm_config` (JSON string),
            `embeddings_model` (URI of the form `<kind>://<details>`), and
            `output_data` (path for the dummy success output).
        check_completion: Whether the completion deployment must be validated.
        check_embeddings: Whether the embedding deployment must be validated.
        activity_logger: Logger used for telemetry/activity events.

    Raises:
        Exception: If a required connection ID is missing, or a connection
            yields no usable key-based credential.
    """
    completion_params = {}
    embedding_params = {}
    llm_config = json.loads(parser_args.llm_config)
    print(f"Using llm_config: {json.dumps(llm_config, indent=2)}")
    connection_id_completion = os.environ.get(
        "AZUREML_WORKSPACE_CONNECTION_ID_AOAI_COMPLETION", None)
    connection_id_embedding = os.environ.get(
        "AZUREML_WORKSPACE_CONNECTION_ID_AOAI_EMBEDDING", None)
    activity_logger.info(
        "[Validate Deployments]: Received and parsed arguments for validating deployments in RAG. "
        + " Starting validation now...")

    if connection_id_completion and check_completion:
        connection = get_connection_by_id_v2(connection_id_completion)
        credential = workspace_connection_to_credential(connection)
        completion_params = _build_aoai_model_params(
            connection,
            credential,
            connection_id_completion,
            llm_config["deployment_name"],
            llm_config["model_name"],
            "Completion model",
            activity_logger)
        if not completion_params:
            activity_logger.info(
                "ValidationFailed: Completion model LLM connection was unable to pull information")
            raise Exception(
                "Completion model LLM connection was unable to pull information")
        print(
            f"Completion model deployment name: {completion_params['deployment_id']}")
        print(
            f"Completion model name: {completion_params['model_name']}")
        activity_logger.info(
            "[Validate Deployments]: Completion workspace connection retrieved and params populated successfully...",
            extra={
                'properties': {
                    'connection': connection_id_completion,
                    'openai_api_type': completion_params["openai_api_type"],
                    'model_name': completion_params["model_name"],
                    'deployment_name': completion_params["deployment_id"],
                    'is_default_aoai': "default_aoai_name" in completion_params}})
    elif check_completion:
        # Bug fix: this message previously said "check_embeddings = True",
        # contradicting both the branch condition and the exception below.
        activity_logger.info(
            "ValidationFailed: ConnectionID for LLM is empty and check_completion = True")
        raise Exception(
            "ConnectionID for LLM is empty and check_completion = True")

    # Embedding connection will not be passed in for Existing ACS scenario
    if connection_id_embedding and check_embeddings:
        connection = get_connection_by_id_v2(connection_id_embedding)
        credential = workspace_connection_to_credential(connection)
        _, details = parser_args.embeddings_model.split('://')
        # Parse the model URI once instead of once per field.
        model_details = split_details(details, start=0)
        embedding_params = _build_aoai_model_params(
            connection,
            credential,
            connection_id_embedding,
            model_details["deployment"],
            model_details["model"],
            "Embedding model",
            activity_logger)
        if not embedding_params:
            activity_logger.info(
                "ValidationFailed: Embedding model connection was unable to pull information")
            raise Exception(
                "Embeddings connection was unable to pull information")
        print(
            f"Embedding deployment name: {embedding_params['deployment_id']}")
        print(
            f"Embedding model name: {embedding_params['model_name']}")
        print("Using workspace connection key for OpenAI embeddings")
        activity_logger.info(
            "[Validate Deployments]: Embedding workspace connection retrieved and params populated successfully...",
            extra={
                'properties': {
                    'connection': connection_id_embedding,
                    'openai_api_type': embedding_params["openai_api_type"],
                    'model_name': embedding_params["model_name"],
                    'deployment_name': embedding_params["deployment_id"],
                    'is_default_aoai': "default_aoai_name" in embedding_params}})
    elif check_embeddings:
        activity_logger.info(
            "ValidationFailed: ConnectionID for Embeddings is empty and check_embeddings = True")
        raise Exception(
            "ConnectionID for Embeddings is empty and check_embeddings = True")

    poll_on_deployment(completion_params,
                       embedding_params, activity_logger)
    # Dummy output to allow step ordering
    with open(parser_args.output_data, "w") as f:
        json.dump({"deployment_validation_success": "true"}, f)
    activity_logger.info(
        "[Validate Deployments]: Success! AOAI deployments have been validated.")


def _build_aoai_model_params(connection, credential, connection_id,
                             deployment_name, model_name, model_label,
                             activity_logger: Logger):
    """Build the OpenAI parameter dict for one model from a workspace connection.

    Returns an empty dict when the credential exposes no ``key`` attribute
    (i.e. it is not a key-based credential) — the caller treats that as a
    validation failure.

    Args:
        connection: Workspace connection object (provides `target`/`metadata`).
        credential: Credential derived from the connection.
        connection_id: The workspace connection resource ID.
        deployment_name: AOAI deployment name for this model.
        model_name: AOAI model name for this model.
        model_label: Human-readable label ("Completion model" / "Embedding
            model") used in the default-connection log message.
        activity_logger: Logger used for telemetry/activity events.
    """
    params = {}
    if not hasattr(credential, 'key'):
        return params
    params["deployment_id"] = deployment_name
    params["model_name"] = model_name
    params["openai_api_key"] = credential.key
    params["openai_api_base"] = connection.target
    connection_metadata = connection.metadata
    # Metadata casing differs between connection sources; check both spellings.
    params["openai_api_type"] = connection_metadata.get(
        'apiType',
        connection_metadata.get('ApiType', "azure"))
    params["openai_api_version"] = connection_metadata.get(
        'apiVersion',
        connection_metadata.get('ApiVersion', "2023-03-15-preview"))
    params["connection"] = connection_id
    # Name is currently the only distinguishing factor between default and non-default
    # Default connection is the only one which can perform control plane operations,
    # as AI Studio does not allow selecting of ResourceID in their UI yet.
    if is_default_connection(connection):
        activity_logger.info(
            f"[Validate Deployments]: {model_label} using Default AOAI connection, parsing ResourceId")
        cog_workspace_details = split_details(
            connection_metadata["ResourceId"], start=1)
        params["default_aoai_name"] = cog_workspace_details["accounts"]
        params["resource_group"] = cog_workspace_details["resourceGroups"]
    return params