code/embedding-function/utilities/search/azure_search_handler.py (173 lines of code) (raw):

import json
import logging
from typing import List

import tiktoken
from azure.search.documents.models import VectorizedQuery

from ..common.source_document import SourceDocument
from ..helpers.azure_computer_vision_client import AzureComputerVisionClient
from ..helpers.azure_search_helper import AzureSearchHelper
from ..helpers.llm_helper import LLMHelper
from .search_handler_base import SearchHandlerBase

logger = logging.getLogger(__name__)


class AzureSearchHandler(SearchHandlerBase):
    """Search handler backed by Azure AI Search.

    Provides document lookup, grouping and deletion for the embedding
    pipeline, plus answer-time retrieval (`query_search`) using either
    semantic or hybrid (keyword + vector) search against the index managed
    by :class:`AzureSearchHelper`.
    """

    # tiktoken encoding used to tokenise questions before embedding.
    _ENCODER_NAME = "cl100k_base"

    def __init__(self, env_helper):
        """Create the handler with its LLM and Computer Vision helper clients.

        Args:
            env_helper: environment/configuration accessor shared by the app.
        """
        super().__init__(env_helper)
        self.llm_helper = LLMHelper()
        self.azure_computer_vision_client = AzureComputerVisionClient(env_helper)

    def create_search_client(self):
        """Return a search client bound to the configured index."""
        return AzureSearchHelper().get_search_client()

    def perform_search(self, filename):
        """Return all indexed chunks whose title equals *filename*.

        Bug fix: the title filter previously hard-coded a literal value, so
        the ``filename`` argument was silently ignored.
        """
        return self.search_client.search(
            "*",
            select="title, content, metadata",
            filter=f"title eq '{filename}'",
        )

    def process_results(self, results):
        """Convert raw search results into ``[chunk_id, content]`` pairs.

        Falls back to the result's position when metadata is missing or
        carries no chunk id (images uploaded with advanced image processing
        do not have a chunk ID).

        Returns:
            List of two-element lists ``[chunk, content]``; empty list when
            *results* is ``None``.
        """
        logger.info("Processing search results")
        if results is None:
            logger.warning("No results found")
            return []
        data = []
        for i, result in enumerate(results):
            # Guard against absent/empty metadata instead of crashing on
            # json.loads(None); use the enumerate index as the fallback chunk.
            metadata = result.get("metadata")
            chunk = json.loads(metadata).get("chunk", i) if metadata else i
            data.append([chunk, result["content"]])
        logger.info("Processed results")
        return data

    def get_files(self):
        """Return id/title of every document in the index, with total count."""
        return self.search_client.search(
            "*", select="id, title", include_total_count=True
        )

    def output_results(self, results):
        """Group search results into a ``{title: [ids]}`` mapping."""
        files = {}
        for result in results:
            doc_id = result["id"]  # renamed from `id` to avoid shadowing the builtin
            filename = result["title"]
            files.setdefault(filename, []).append(doc_id)
        return files

    def delete_files(self, files):
        """Delete every chunk listed in *files*.

        Args:
            files: mapping of ``{filename: [chunk ids]}`` as produced by
                :meth:`output_results`.

        Returns:
            Comma-separated string of the filenames whose chunks were deleted.
        """
        ids_to_delete = []
        files_to_delete = []
        for filename, ids in files.items():
            files_to_delete.append(filename)
            ids_to_delete += [{"id": doc_id} for doc_id in ids]
        self.search_client.delete_documents(ids_to_delete)
        return ", ".join(files_to_delete)

    def search_by_blob_url(self, blob_url):
        """Return documents whose source matches *blob_url* (SAS-placeholder form)."""
        return self.search_client.search(
            "*",
            select="id, title",
            include_total_count=True,
            filter=f"source eq '{blob_url}_SAS_TOKEN_PLACEHOLDER_'",
        )

    def search_by_sharepoint_file_id(self, sharepoint_file_id):
        """Return documents originating from the given SharePoint file id."""
        return self.search_client.search(
            "*",
            select="id, title",
            include_total_count=True,
            filter=f"sharepoint_file_id eq '{sharepoint_file_id}'",
        )

    def query_search(self, question) -> List[SourceDocument]:
        """Answer-time retrieval: embed *question* and query the index.

        Uses semantic search when ``AZURE_SEARCH_USE_SEMANTIC_SEARCH`` is
        enabled, hybrid search otherwise; adds an image-vector query when
        advanced image processing is turned on.

        Returns:
            The matching chunks as a list of :class:`SourceDocument`.
        """
        logger.info("Performing query search for question: %s", question)
        encoding = tiktoken.get_encoding(self._ENCODER_NAME)
        tokenised_question = encoding.encode(question)

        if self.env_helper.USE_ADVANCED_IMAGE_PROCESSING:
            logger.info("Using advanced image processing for vectorization")
            vectorized_question = self.azure_computer_vision_client.vectorize_text(
                question
            )
        else:
            logger.info("Skipping advanced image processing")
            vectorized_question = None

        if self.env_helper.AZURE_SEARCH_USE_SEMANTIC_SEARCH:
            logger.info("Performing semantic search")
            results = self._semantic_search(
                question, tokenised_question, vectorized_question
            )
        else:
            logger.info("Performing hybrid search")
            results = self._hybrid_search(
                question, tokenised_question, vectorized_question
            )

        logger.info("Converting search results to SourceDocument list")
        return self._convert_to_source_documents(results)

    def _semantic_search(
        self,
        question: str,
        tokenised_question: list[int],
        vectorized_question: list[float] | None,
    ):
        """Run a semantic (reranked) search combining text and vector queries."""
        return self.search_client.search(
            search_text=question,
            vector_queries=[
                VectorizedQuery(
                    vector=self.llm_helper.generate_embeddings(tokenised_question),
                    k_nearest_neighbors=self.env_helper.AZURE_SEARCH_TOP_K,
                    fields=self._VECTOR_FIELD,
                ),
                # Only add the image-vector query when a vectorized question exists.
                *(
                    [
                        VectorizedQuery(
                            vector=vectorized_question,
                            k_nearest_neighbors=self.env_helper.AZURE_SEARCH_TOP_K,
                            fields=self._IMAGE_VECTOR_FIELD,
                        )
                    ]
                    if vectorized_question is not None
                    else []
                ),
            ],
            filter=self.env_helper.AZURE_SEARCH_FILTER,
            query_type="semantic",
            semantic_configuration_name=self.env_helper.AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG,
            query_caption="extractive",
            query_answer="extractive",
            top=self.env_helper.AZURE_SEARCH_TOP_K,
        )

    def _hybrid_search(
        self,
        question: str,
        tokenised_question: list[int],
        vectorized_question: list[float] | None,
    ):
        """Run a hybrid (keyword + vector) search without semantic reranking."""
        return self.search_client.search(
            search_text=question,
            vector_queries=[
                VectorizedQuery(
                    vector=self.llm_helper.generate_embeddings(tokenised_question),
                    k_nearest_neighbors=self.env_helper.AZURE_SEARCH_TOP_K,
                    # NOTE(review): ``filter`` is not a documented VectorizedQuery
                    # parameter; the top-level ``filter`` below is what actually
                    # restricts results — confirm whether this kwarg can be dropped.
                    filter=self.env_helper.AZURE_SEARCH_FILTER,
                    fields=self._VECTOR_FIELD,
                ),
                # Only add the image-vector query when a vectorized question exists.
                *(
                    [
                        VectorizedQuery(
                            vector=vectorized_question,
                            k_nearest_neighbors=self.env_helper.AZURE_SEARCH_TOP_K,
                            fields=self._IMAGE_VECTOR_FIELD,
                        )
                    ]
                    if vectorized_question is not None
                    else []
                ),
            ],
            query_type="simple",  # this is the default value
            filter=self.env_helper.AZURE_SEARCH_FILTER,
            top=self.env_helper.AZURE_SEARCH_TOP_K,
        )

    def _convert_to_source_documents(self, search_results) -> List[SourceDocument]:
        """Map raw search hits onto :class:`SourceDocument` instances."""
        return [
            SourceDocument(
                id=source.get("id"),
                content=source.get("content"),
                title=source.get("title"),
                source=source.get("source"),
                chunk=source.get("chunk"),
                offset=source.get("offset"),
                page_number=source.get("page_number"),
            )
            for source in search_results
        ]