tools/ragindex/vector_index_retrieval.py (275 lines of code) (raw):

import os import re import time import json import logging import asyncio from typing import Annotated, Optional, List, Dict, Any from urllib.parse import urlparse import aiohttp from azure.identity import ManagedIdentityCredential, AzureCliCredential, ChainedTokenCredential from connectors import AzureOpenAIClient from .types import ( VectorIndexRetrievalResult, MultimodalVectorIndexRetrievalResult, DataPointsResult, ) # ----------------------------------------------------------------------------- # Helper Functions # ----------------------------------------------------------------------------- async def _get_azure_search_token() -> str: """ Acquires an Azure Search access token using chained credentials. """ try: credential = ChainedTokenCredential( ManagedIdentityCredential(), AzureCliCredential() ) # Wrap the synchronous token acquisition in a thread. token_obj = await asyncio.to_thread(credential.get_token, "https://search.azure.com/.default") return token_obj.token except Exception as e: logging.error("Error obtaining Azure Search token.", exc_info=True) raise Exception("Failed to obtain Azure Search token.") from e async def _perform_search(url: str, headers: Dict[str, str], body: Dict[str, Any]) -> Dict[str, Any]: """ Performs an asynchronous HTTP POST request to the given URL with the provided headers and body. Returns the parsed JSON response. Raises: Exception: When the request fails or returns an error status. """ async with aiohttp.ClientSession() as session: try: async with session.post(url, headers=headers, json=body) as response: if response.status >= 400: text = await response.text() error_message = f"Error {response.status}: {text}" logging.error(f"[_perform_search] {error_message}") raise Exception(error_message) return await response.json() except Exception as e: logging.error("Error during asynchronous HTTP request.", exc_info=True) raise Exception("Failed to execute search query.") from e # ----------------------------------------------------------------------------- # Main Functions # ----------------------------------------------------------------------------- async def vector_index_retrieve( input: Annotated[ str, "An optimized query string based on the user's ask and conversation history, when available" ], security_ids: str = 'anonymous' ) -> Annotated[ VectorIndexRetrievalResult, "A Pydantic model containing the search results as a string" ]: """ Performs a vector search against Azure Cognitive Search and returns the results wrapped in a Pydantic model. If an error occurs, the 'error' field is populated. """ aoai = AzureOpenAIClient() search_top_k = os.getenv('AZURE_SEARCH_TOP_K', 3) search_approach = os.getenv('AZURE_SEARCH_APPROACH', 'hybrid') semantic_search_config = os.getenv('AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG', 'my-semantic-config') search_service = os.getenv('AZURE_SEARCH_SERVICE') search_index = os.getenv('AZURE_SEARCH_INDEX', 'ragindex') search_api_version = os.getenv('AZURE_SEARCH_API_VERSION', '2024-07-01') use_semantic = os.getenv('AZURE_SEARCH_USE_SEMANTIC', 'false').lower() == 'true' VECTOR_SEARCH_APPROACH = 'vector' TERM_SEARCH_APPROACH = 'term' HYBRID_SEARCH_APPROACH = 'hybrid' search_results: List[str] = [] error_message: Optional[str] = None search_query = input try: # Generate embeddings for the query. start_time = time.time() logging.info(f"[vector_index_retrieve] Generating question embeddings. Search query: {search_query}") # Wrap synchronous get_embeddings in a thread. embeddings_query = await asyncio.to_thread(aoai.get_embeddings, search_query) response_time = round(time.time() - start_time, 2) logging.info(f"[vector_index_retrieve] Finished generating embeddings in {response_time} seconds") # Acquire token for Azure Search. azure_search_token = await _get_azure_search_token() # Prepare the request body. body: Dict[str, Any] = { "select": "title, content, url, filepath, chunk_id", "top": search_top_k } if search_approach == TERM_SEARCH_APPROACH: body["search"] = search_query elif search_approach == VECTOR_SEARCH_APPROACH: body["vectorQueries"] = [{ "kind": "vector", "vector": embeddings_query, "fields": "contentVector", "k": int(search_top_k) }] elif search_approach == HYBRID_SEARCH_APPROACH: body["search"] = search_query body["vectorQueries"] = [{ "kind": "vector", "vector": embeddings_query, "fields": "contentVector", "k": int(search_top_k) }] # Apply security filter. filter_str = ( f"metadata_security_id/any(g:search.in(g, '{security_ids}')) " f"or not metadata_security_id/any()" ) body["filter"] = filter_str logging.debug(f"[vector_index_retrieve] Search filter: {filter_str}") headers = { 'Content-Type': 'application/json', 'Authorization': f'Bearer {azure_search_token}' } search_url = ( f"https://{search_service}.search.windows.net/indexes/{search_index}/docs/search" f"?api-version={search_api_version}" ) # Execute the search query asynchronously. start_time = time.time() response_json = await _perform_search(search_url, headers, body) elapsed = round(time.time() - start_time, 2) logging.info(f"[vector_index_retrieve] Finished querying Azure Cognitive Search in {elapsed} seconds") if response_json.get('value'): logging.info(f"[vector_index_retrieve] {len(response_json['value'])} documents retrieved") for doc in response_json['value']: url = doc.get('url', '') uri = re.sub(r'https://[^/]+\.blob\.core\.windows\.net', '', url) content_str = doc.get('content', '').strip() search_results.append(f"{uri}: {content_str}\n") else: logging.info("[vector_index_retrieve] No documents retrieved") except Exception as e: error_message = f"Exception occurred: {e}" logging.error(f"[vector_index_retrieve] {error_message}", exc_info=True) # Join the retrieved results into a single string. sources = ' '.join(search_results) return VectorIndexRetrievalResult(result=sources, error=error_message) def extract_captions(str_captions): # Regular expression pattern to match image references followed by their descriptions pattern = r"\[.*?\]:\s(.*?)(?=\[.*?\]:|$)" # Find all matches matches = re.findall(pattern, str_captions, re.DOTALL) return [match.strip() for match in matches] def replace_image_filenames_with_urls(content: str, related_images: list) -> str: """ Replace image filenames or relative paths in the content string with their corresponding full URLs from the related_images list. """ for image_url in related_images: # Parse the URL and remove the leading slash from the path logging.debug(f"[multimodal_vector_index_retrieve] image_url: {image_url}.") parsed_url = urlparse(image_url) image_path = parsed_url.path.lstrip('/') # e.g., 'documents-images/myfolder/filename.png' logging.debug(f"[multimodal_vector_index_retrieve] image_path: {image_path}.") # Replace occurrences of the relative path in the content with the full URL content = content.replace(image_path, image_url) logging.debug(f"[multimodal_vector_index_retrieve] content: {content}.") # Also replace only the filename if it appears alone # filename = image_path.split('/')[-1] # content = content.replace(filename, image_url) return content async def multimodal_vector_index_retrieve( input: Annotated[ str, "An optimized query string based on the user's ask and conversation history, when available" ], security_ids: str = 'anonymous' ) -> Annotated[ MultimodalVectorIndexRetrievalResult, "A Pydantic model containing the search results with separate lists for texts and images" ]: """ Variation of vector_index_retrieve that fetches text and related images from the search index. Returns the results wrapped in a Pydantic model with separate lists for texts and images. """ aoai = AzureOpenAIClient() search_top_k = int(os.getenv('AZURE_SEARCH_TOP_K', 3)) search_approach = os.getenv('AZURE_SEARCH_APPROACH', 'vector') # or 'hybrid' semantic_search_config = os.getenv('AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG', 'my-semantic-config') search_service = os.getenv('AZURE_SEARCH_SERVICE') search_index = os.getenv('AZURE_SEARCH_INDEX', 'ragindex') search_api_version = os.getenv('AZURE_SEARCH_API_VERSION', '2024-07-01') use_semantic = os.getenv('AZURE_SEARCH_USE_SEMANTIC', 'false').lower() == 'true' logging.info(f"[multimodal_vector_index_retrieve] User input: {input}") text_results: List[str] = [] image_urls: List[List[str]] = [] captions: List[str] = [] error_message: Optional[str] = None # 1. Generate embeddings for the query. try: start_time = time.time() embeddings_query = await asyncio.to_thread(aoai.get_embeddings, input) embedding_time = round(time.time() - start_time, 2) logging.info(f"[multimodal_vector_index_retrieve] Query embeddings took {embedding_time} seconds") except Exception as e: error_message = f"Error generating embeddings: {e}" logging.error(f"[multimodal_vector_index_retrieve] {error_message}", exc_info=True) return MultimodalVectorIndexRetrievalResult( texts=[], images=[], error=error_message ) # 2. Acquire Azure Search token. try: azure_search_token = await _get_azure_search_token() except Exception as e: error_message = f"Error acquiring token for Azure Search: {e}" logging.error(f"[multimodal_vector_index_retrieve] {error_message}", exc_info=True) return MultimodalVectorIndexRetrievalResult( texts=[], images=[], error=error_message ) # 3. Build the request body. body: Dict[str, Any] = { "select": "title, content, filepath, url, imageCaptions, relatedImages", "top": search_top_k, "vectorQueries": [ { "kind": "vector", "vector": embeddings_query, "fields": "contentVector", "k": int(search_top_k) }, { "kind": "vector", "vector": embeddings_query, "fields": "captionVector", "k": int(search_top_k) } ] } if use_semantic and search_approach != "vector": body["queryType"] = "semantic" body["semanticConfiguration"] = semantic_search_config # Apply security filter. filter_str = ( f"metadata_security_id/any(g:search.in(g, '{security_ids}')) " "or not metadata_security_id/any()" ) body["filter"] = filter_str headers = { 'Content-Type': 'application/json', 'Authorization': f'Bearer {azure_search_token}' } search_url = ( f"https://{search_service}.search.windows.net" f"/indexes/{search_index}/docs/search" f"?api-version={search_api_version}" ) # 4. Query Azure Search. try: start_time = time.time() response_json = await _perform_search(search_url, headers, body) response_time = round(time.time() - start_time, 2) logging.info(f"[multimodal_vector_index_retrieve] Finished querying Azure AI Search in {response_time} seconds") for doc in response_json.get('value', []): content = doc.get('content', '') str_captions = doc.get('imageCaptions', '') captions.append(extract_captions(str_captions)) url = doc.get('url', '') # Convert blob URL to relative path uri = re.sub(r'https://[^/]+\.blob\.core\.windows\.net', '', url) text_results.append(f"{uri}: {content.strip()}") # Replace image filenames with URLs content = replace_image_filenames_with_urls(content, doc.get('relatedImages', [])) # Extract image URLs from <figure> tags # doc_image_urls = re.findall(r'<figure>(https?://.*?)</figure>', content) # image_urls.append(doc_image_urls) image_urls.append(doc.get('relatedImages', [])) # Replace <figure>...</figure> with <img src="..."> # content = re.sub(r'<figure>(https?://\S+)</figure>', r'<img src="\1">', content) except Exception as e: error_message = f"Exception in retrieval: {e}" logging.error(f"[multimodal_vector_index_retrieve] {error_message}", exc_info=True) return MultimodalVectorIndexRetrievalResult( texts=text_results, images=image_urls, captions=captions, error=error_message ) def get_data_points_from_chat_log(chat_log: list) -> DataPointsResult: """ Parses a chat log to extract data points (e.g., filenames with extension) from tool call events. Returns a Pydantic model containing the list of extracted data points. """ # Regex patterns. request_call_id_pattern = re.compile(r"id='([^']+)'") request_function_name_pattern = re.compile(r"name='([^']+)'") exec_call_id_pattern = re.compile(r"call_id='([^']+)'") exec_content_pattern = re.compile(r"content='(.+?)', call_id=", re.DOTALL) # Allowed file extensions. allowed_extensions = ['vtt', 'xlsx', 'xls', 'pdf', 'docx', 'pptx', 'png', 'jpeg', 'jpg', 'bmp', 'tiff'] filename_pattern = re.compile( rf"([^\s:]+\.(?:{'|'.join(allowed_extensions)})\s*:\s*.*?)(?=[^\s:]+\.(?:{'|'.join(allowed_extensions)})\s*:|$)", re.IGNORECASE | re.DOTALL ) relevant_call_ids = set() data_points = [] for msg in chat_log: if msg["message_type"] == "ToolCallRequestEvent": content = msg["content"][0] call_id_match = request_call_id_pattern.search(content) function_name_match = request_function_name_pattern.search(content) if call_id_match and function_name_match: if function_name_match.group(1) == "vector_index_retrieve_wrapper": relevant_call_ids.add(call_id_match.group(1)) elif msg["message_type"] == "ToolCallExecutionEvent": content = msg["content"][0] call_id_match = exec_call_id_pattern.search(content) if call_id_match and call_id_match.group(1) in relevant_call_ids: content_part_match = exec_content_pattern.search(content) if not content_part_match: continue content_part = content_part_match.group(1) try: parsed = json.loads(content_part) texts = parsed.get("texts", []) except json.JSONDecodeError: texts = [re.split(r'["\']images["\']\s*:\s*\[', content_part, 1, re.IGNORECASE)[0]] for text in texts: text = bytes(text, "utf-8").decode("unicode_escape") for match in filename_pattern.findall(text): extracted = match.strip(" ,\\\"").lstrip("[").rstrip("],") if extracted: data_points.append(extracted) return DataPointsResult(data_points=data_points)