tools/database/measures.py (84 lines of code) (raw):

import os import time import logging import asyncio from typing import Any, Dict, List, Optional, Annotated import aiohttp from azure.identity import ChainedTokenCredential, ManagedIdentityCredential, AzureCliCredential from .types import MeasuresList, MeasureItem # ----------------------------------------------------------------------------- # Helper function to perform the Azure AI Search query (analogous to tables.py) # ----------------------------------------------------------------------------- async def _perform_search(body: Dict[str, Any], search_index: str) -> Dict[str, Any]: """ Executes a search query against the specified Azure AI Search index. Args: body (dict): The JSON body for the search request. search_index (str): The name of the search index to query. Returns: dict: The JSON response from the search service. """ search_service = os.getenv("AZURE_SEARCH_SERVICE") if not search_service: raise Exception("AZURE_SEARCH_SERVICE environment variable is not set.") search_api_version = os.getenv("AZURE_SEARCH_API_VERSION", "2024-07-01") # Build the search endpoint URL. search_endpoint = ( f"https://{search_service}.search.windows.net/indexes/{search_index}/docs/search" f"?api-version={search_api_version}" ) # Obtain an access token for the search service. try: credential = ChainedTokenCredential( ManagedIdentityCredential(), AzureCliCredential() ) azure_search_scope = "https://search.azure.com/.default" token = credential.get_token(azure_search_scope).token except Exception as e: logging.error("Error obtaining Azure Search token.", exc_info=True) raise Exception("Failed to obtain Azure Search token.") from e headers = { "Content-Type": "application/json", "Authorization": f"Bearer {token}" } async with aiohttp.ClientSession() as session: try: async with session.post(search_endpoint, headers=headers, json=body) as response: if response.status >= 400: text = await response.text() error_message = f"Status code: {response.status}. Error: {text}" logging.error(f"[measures] {error_message}") raise Exception(error_message) result = await response.json() return result except Exception as e: logging.error("Error during the search HTTP request.", exc_info=True) raise Exception("Failed to execute search query.") from e # ----------------------------------------------------------------------------- # Function to retrieve all measures info from the Azure AI Search index # ----------------------------------------------------------------------------- async def measures_retrieval( datasource: Annotated[str, "Name of the target datasource"] ) -> MeasuresList: """ Retrieve a list of measures filtering by the given datasource. Each entry will include the following fields: - name - description - datasource - type - source_table - data_type - source_model Returns: MeasuresList: Contains a list of MeasureItem objects and an optional error message. """ search_index = "nl2sql-measures" safe_datasource = datasource.replace("'", "''") filter_expression = f"datasource eq '{safe_datasource}'" body = { "search": "*", "filter": filter_expression, "select": "name, description, datasource, type, source_table, data_type, source_model", "top": 1000 # Adjust based on your expected document count. } logging.info(f"[measures] Querying Azure AI Search for measures in datasource '{datasource}'") measures_info: List[MeasureItem] = [] error_message: Optional[str] = None try: start_time = time.time() result = await _perform_search(body, search_index) elapsed = round(time.time() - start_time, 2) logging.info(f"[measures] Finished querying measures in {elapsed} seconds") for doc in result.get("value", []): measure_item = MeasureItem( name=doc.get("name", ""), description=doc.get("description", ""), datasource=doc.get("datasource", ""), type=doc.get("type", None), source_table=doc.get("source_table", None), data_type=doc.get("data_type", None), source_model=doc.get("source_model", None) ) measures_info.append(measure_item) except Exception as e: error_message = str(e) logging.error(f"[measures] Error querying measures: {error_message}") if not measures_info: return MeasuresList( measures=[], error=f"No datasource with name '{datasource}' was found. {error_message or ''}".strip() ) return MeasuresList(measures=measures_info, error=error_message)