tools/database/tables.py (202 lines of code) (raw):

import os
import time
import logging
import asyncio
from typing import Any, Dict, Annotated, Optional, List

import aiohttp
from azure.identity import ChainedTokenCredential, ManagedIdentityCredential, AzureCliCredential

# Import Pydantic models from your types file.
from .types import (
    TablesList,
    TableItem,
    SchemaInfo,
    TablesRetrievalResult,
    TableRetrievalItem
)

# Import the AzureOpenAIClient for generating embeddings.
from connectors import AzureOpenAIClient


# -----------------------------------------------------------------------------
# Helper to embed caller-supplied values safely in an OData filter expression
# -----------------------------------------------------------------------------
def _escape_odata_literal(value: str) -> str:
    """Return *value* with single quotes doubled for use inside an OData string literal."""
    return value.replace("'", "''")


# -----------------------------------------------------------------------------
# Helper function to perform the Azure AI Search query using aiohttp
# -----------------------------------------------------------------------------
async def _perform_search(body: Dict[str, Any], search_index: str) -> Dict[str, Any]:
    """
    Executes a search query against the specified Azure AI Search index.

    The service name is read from AZURE_SEARCH_SERVICE and the API version from
    AZURE_SEARCH_API_VERSION (default "2024-07-01"). A bearer token is obtained
    through a managed identity first, falling back to the Azure CLI login.

    Args:
        body (dict): The JSON body for the search request.
        search_index (str): The name of the search index to query.

    Returns:
        dict: The JSON response from the search service.

    Raises:
        Exception: If AZURE_SEARCH_SERVICE is not set, the token cannot be
            obtained, or the search request fails (including HTTP status >= 400).
    """
    search_service = os.getenv("AZURE_SEARCH_SERVICE")
    if not search_service:
        raise Exception("AZURE_SEARCH_SERVICE environment variable is not set.")
    search_api_version = os.getenv("AZURE_SEARCH_API_VERSION", "2024-07-01")

    # Build the search endpoint URL.
    search_endpoint = (
        f"https://{search_service}.search.windows.net/indexes/{search_index}/docs/search"
        f"?api-version={search_api_version}"
    )

    # Obtain an access token for the search service.
    try:
        credential = ChainedTokenCredential(
            ManagedIdentityCredential(),
            AzureCliCredential()
        )
        azure_search_scope = "https://search.azure.com/.default"
        # get_token() is a synchronous call; run it in a worker thread so the
        # event loop is not blocked while the credential chain is evaluated.
        access_token = await asyncio.to_thread(credential.get_token, azure_search_scope)
        token = access_token.token
    except Exception as e:
        logging.error("Error obtaining Azure Search token.", exc_info=True)
        raise Exception("Failed to obtain Azure Search token.") from e

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}"
    }

    # Perform the asynchronous HTTP POST request.
    async with aiohttp.ClientSession() as session:
        try:
            async with session.post(search_endpoint, headers=headers, json=body) as response:
                if response.status >= 400:
                    text = await response.text()
                    error_message = f"Status code: {response.status}. Error: {text}"
                    logging.error(f"[tables] {error_message}")
                    raise Exception(error_message)
                result = await response.json()
                return result
        except Exception as e:
            logging.error("Error during the search HTTP request.", exc_info=True)
            # Callers surface str(exception) in their `error` field, so keep the
            # underlying cause (e.g. status code and service message) in the text.
            raise Exception(f"Failed to execute search query. {e}") from e


# -----------------------------------------------------------------------------
# Function to retrieve all tables info from the Azure AI Search index
# -----------------------------------------------------------------------------
async def get_all_tables_info(
    datasource: Annotated[str, "Name of the target datasource"]
) -> TablesList:
    """
    Retrieve a list of tables filtering by the given datasource.
    Each entry will have "table", "description", and "datasource".

    Errors are not raised; they are reported through the result's `error` field.

    Returns:
        TablesList: Contains a list of TableItem objects and an optional error message.
            When no table is found (or the query failed), `tables` is empty and
            `error` explains that the datasource was not found.
    """
    search_index = "nl2sql-tables"
    safe_datasource = _escape_odata_literal(datasource)
    filter_expression = f"datasource eq '{safe_datasource}'"

    body = {
        "search": "*",
        "filter": filter_expression,
        "select": "table, description, datasource",
        "top": 1000  # Adjust based on your expected document count.
    }

    logging.info(f"[tables] Querying Azure AI Search for tables in datasource '{datasource}'")
    tables_info: List[TableItem] = []
    error_message: Optional[str] = None

    try:
        start_time = time.time()
        result = await _perform_search(body, search_index)
        elapsed = round(time.time() - start_time, 2)
        logging.info(f"[tables] Finished querying tables in {elapsed} seconds")

        for doc in result.get("value", []):
            table_item = TableItem(
                table=doc.get("table", ""),
                description=doc.get("description", ""),
                datasource=doc.get("datasource", "")
            )
            tables_info.append(table_item)
    except Exception as e:
        error_message = str(e)
        logging.error(f"[tables] Error querying tables: {error_message}")

    if not tables_info:
        return TablesList(
            tables=[],
            error=f"No datasource with name '{datasource}' was found. {error_message or ''}".strip()
        )

    return TablesList(tables=tables_info, error=error_message)


# -----------------------------------------------------------------------------
# Function to retrieve schema information for a given table from the index
# -----------------------------------------------------------------------------
async def get_schema_info(
    datasource: Annotated[str, "Target datasource"],
    table_name: Annotated[str, "Target table"]
) -> SchemaInfo:
    """
    Retrieve schema information for a specific table in a given datasource.
    Returns the table's description and its columns.

    Errors are not raised; they are reported through the result's `error` field.

    Returns:
        SchemaInfo: Contains the schema details or an error message. `columns`
            maps each column name to its description, or is None on error or
            when the table is not found.
    """
    search_index = "nl2sql-tables"
    safe_datasource = _escape_odata_literal(datasource)
    safe_table_name = _escape_odata_literal(table_name)
    filter_expression = f"datasource eq '{safe_datasource}' and table eq '{safe_table_name}'"

    body = {
        "search": "*",
        "filter": filter_expression,
        "select": "table, description, datasource, columns",
        "top": 1
    }

    logging.info(f"[tables] Querying Azure AI Search for schema info for table '{table_name}' in datasource '{datasource}'")
    error_message: Optional[str] = None

    try:
        start_time = time.time()
        result = await _perform_search(body, search_index)
        elapsed = round(time.time() - start_time, 2)
        logging.info(f"[tables] Finished querying schema info in {elapsed} seconds")

        docs = result.get("value", [])
        if not docs:
            error_message = f"Table '{table_name}' not found in datasource '{datasource}'."
            # Report the requested table name; there is no document to read it from.
            return SchemaInfo(
                datasource=datasource,
                table=table_name,
                error=error_message,
                columns=None
            )

        doc = docs[0]
        columns_data = doc.get("columns", [])
        columns: Dict[str, str] = {}
        # Only a list of column entries is understood; anything else yields no columns.
        if isinstance(columns_data, list):
            for col in columns_data:
                col_name = col.get("name")
                col_description = col.get("description", "")
                if col_name:
                    columns[col_name] = col_description

        return SchemaInfo(
            datasource=datasource,
            table=doc.get("table", table_name),
            description=doc.get("description", ""),
            columns=columns
        )
    except Exception as e:
        error_message = str(e)
        logging.error(f"[tables] Error querying schema info: {error_message}")
        return SchemaInfo(
            datasource=datasource,
            table=table_name,
            error=error_message,
            columns=None
        )


# ---------------------------------------------------------------------------
# Function to retrieve necessary tables from the retrieval system
# based on an optimized input query, to construct a response to the user's request.
# ---------------------------------------------------------------------------
async def tables_retrieval(
    input: Annotated[str, "A query string optimized to retrieve necessary tables from the retrieval system to construct a response"],
    datasource: Annotated[Optional[str], "Target datasource"] = None
) -> TablesRetrievalResult:
    """
    Retrieves necessary tables from the retrieval system based on the input query.

    The index name (NL2SQL_TABLES_INDEX), search approach (AZURE_SEARCH_APPROACH:
    "term", "vector" or "hybrid") and semantic ranking settings
    (AZURE_SEARCH_USE_SEMANTIC, AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG) are read
    from environment variables. At most 10 tables are requested.

    Returns:
        TablesRetrievalResult: An object containing a list of TableRetrievalItem objects.
        If an error occurs, the 'error' field is populated.
    """
    # Read search configuration from environment variables.
    search_index = os.getenv("NL2SQL_TABLES_INDEX", "nl2sql-tables")
    search_approach = os.getenv("AZURE_SEARCH_APPROACH", "hybrid").lower()
    search_top_k = 10
    use_semantic = os.getenv("AZURE_SEARCH_USE_SEMANTIC", "false").lower() == "true"
    semantic_search_config = os.getenv("AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG", "my-semantic-config")

    search_query = input  # The optimized query string.
    search_results: List[TableRetrievalItem] = []
    error_message: Optional[str] = None

    try:
        # Generate embeddings for the search query using the Azure OpenAI Client.
        aoai = AzureOpenAIClient()
        logging.info(f"[tables] Generating question embeddings. Search query: {search_query}")
        embeddings_query = await asyncio.to_thread(aoai.get_embeddings, search_query)
        logging.info("[tables] Finished generating question embeddings.")

        # Prepare the request body.
        body: Dict[str, Any] = {
            "select": "table, description",
            "top": search_top_k
        }

        # Apply datasource filter if provided. The value is escaped so a quote
        # in the datasource name cannot break or alter the filter expression.
        if datasource:
            body["filter"] = f"datasource eq '{_escape_odata_literal(datasource)}'"

        # Adjust the body based on the search approach: "term" and "hybrid"
        # send the text query, "vector" and "hybrid" send the embedding.
        if search_approach in ("term", "hybrid"):
            body["search"] = search_query
        if search_approach in ("vector", "hybrid"):
            body["vectorQueries"] = [{
                "kind": "vector",
                "vector": embeddings_query,
                "fields": "contentVector",
                "k": search_top_k
            }]

        # If semantic search is enabled and we're not using vector-only search.
        if use_semantic and search_approach != "vector":
            body["queryType"] = "semantic"
            body["semanticConfiguration"] = semantic_search_config

        logging.info(f"[tables] Querying Azure AI Search for tables. Search query: {search_query}")
        start_time = time.time()
        result = await _perform_search(body, search_index)
        elapsed = round(time.time() - start_time, 2)
        logging.info(f"[tables] Finished querying Azure AI Search in {elapsed} seconds")

        # Process the returned documents.
        if result.get("value"):
            logging.info(f"[tables] {len(result['value'])} documents retrieved")
            for doc in result["value"]:
                table_name = doc.get("table", "")
                description = doc.get("description", "")
                # The index document's datasource is not selected, so the
                # caller-supplied datasource (possibly None) is echoed back.
                search_results.append(TableRetrievalItem(
                    table=table_name,
                    description=description,
                    datasource=datasource
                ))
        else:
            logging.info("[tables] No documents retrieved")
    except Exception as e:
        error_message = str(e)
        logging.error(f"[tables] Error when retrieving tables: {error_message}")

    return TablesRetrievalResult(tables=search_results, error=error_message)