packages/python-packages/apiview-copilot/src/_search_manager.py (268 lines of code) (raw):

from azure.cosmos import CosmosClient
from azure.search.documents import SearchClient, SearchItemPaged
from azure.search.documents.models import (
    VectorizableTextQuery,
    QueryType,
    QueryAnswerType,
    QueryAnswerResult,
    QueryCaptionType,
    SemanticErrorMode,
)
from azure.identity import DefaultAzureCredential
from src._models import Guideline, Example

from collections import deque
import copy
import json
import os
from typing import List, Dict, Optional

if "APPSETTING_WEBSITE_SITE_NAME" not in os.environ:
    # running on dev machine, loadenv
    import dotenv

    dotenv.load_dotenv()

# Cosmos DB
COSMOS_ACC_NAME = os.environ.get("AZURE_COSMOS_ACC_NAME")
COSMOS_DB_NAME = os.environ.get("AZURE_COSMOS_DB_NAME")
COSMOS_ENDPOINT = f"https://{COSMOS_ACC_NAME}.documents.azure.com:443/"

# Azure AI Search
AZURE_SEARCH_NAME = os.environ.get("AZURE_SEARCH_NAME")
SEARCH_ENDPOINT = f"https://{AZURE_SEARCH_NAME}.search.windows.net"

CREDENTIAL = DefaultAzureCredential()

_PACKAGE_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
_GUIDELINES_FOLDER = os.path.join(_PACKAGE_ROOT, "guidelines")


class SearchItem:
    """
    Represents a single search result item.

    The attributes are read from the result dictionary keys ``id``, ``chunk``,
    ``lang``, ``title``, ``@search.score``, ``@search.reranker_score`` and
    ``@search.captions``. Missing keys yield ``None`` (or an empty caption list).
    """

    def __init__(self, result: Dict):
        self.id = result.get("id")
        self.text = result.get("chunk")
        self.lang = result.get("lang")
        self.title = result.get("title")
        self.score = result.get("@search.score")
        self.reranker_score = result.get("@search.reranker_score")
        # The key may be present with a None value, so a `.get` default is not
        # enough to guarantee an iterable here.
        self.captions = [SearchCaption(caption) for caption in result.get("@search.captions") or []]


class SearchAnswer:
    """
    Represents a single answer from the search results.
    """

    def __init__(self, result: QueryAnswerResult):
        self.text = result.text
        self.score = result.score
        self.highlights = result.highlights


class SearchCaption:
    """
    Represents a single caption from the search results.
    """

    def __init__(self, result):
        # `result` is accessed by attribute, not by key, so it must be an object
        # exposing `text` and `highlights`.
        # NOTE(review): presumably an azure.search.documents QueryCaptionResult - confirm.
        self.text = result.text
        self.highlights = result.highlights


class SearchResult:
    """
    Represents the search results.

    Iterating a SearchResult yields its SearchItem objects and ``len()`` returns
    how many there are. Answers are exposed separately through ``answers``.
    """

    def __init__(self, search_results: SearchItemPaged[Dict]):
        # Consume the paged results before asking for the answers, in the same
        # order as the items were returned.
        self.results = [SearchItem(result) for result in search_results]
        # get_answers() can return None when the service produced no answers.
        self.answers = [SearchAnswer(answer) for answer in search_results.get_answers() or []]

    def __len__(self):
        return len(self.results)

    def __iter__(self):
        return iter(self.results)


class Context:
    """
    Represents the resolved context of a search with objects from the CosmosDB database.

    Each guideline becomes one ContextItem; the examples are indexed by ID so
    that every item can pick up the examples its guideline refers to.
    """

    items: List["ContextItem"]

    def __init__(
        self,
        *,
        guidelines: Optional[List[Guideline]] = None,
        examples: Optional[List[Example]] = None,
    ):
        # Both parameters default to None, so treat None as "nothing supplied".
        example_dict = {x.id: x for x in examples or []}
        self.items = [ContextItem(guideline, example_dict) for guideline in guidelines or []]

    def __iter__(self):
        """
        Returns an iterator over the context items.
        """
        yield from self.items

    def __len__(self):
        """
        Returns the number of items in the context.
        """
        return len(self.items)

    def __repr__(self):
        return f"Context(items={len(self.items)})"

    def to_markdown(self) -> str:
        """
        Converts the context to a markdown string.

        Every item is rendered with ``ContextItem.to_markdown`` and prefixed
        with a newline.
        """
        return "".join(f"\n{item.to_markdown()}" for item in self.items)


class ContextItem:
    """
    Represents a single item in the context.

    Holds a guideline's ID (converted to its web form), content, language and
    title, together with copies of the examples the guideline refers to.
    """

    def __init__(self, result: Guideline, examples: Dict[str, Example]):
        self.id = self._process_id(result.id)
        self.content = result.content
        self.lang = result.lang
        self.title = result.title
        self.examples = []
        for ex_id in result.related_examples or []:
            source = examples.get(ex_id)
            if source is None:
                print(f"WARNING: Example {ex_id} not found for guideline {result.id}. Skipping.")
                continue
            # copy the example to a new object so the shared instance keeps
            # the attributes that are stripped below
            example = copy.deepcopy(source)
            del example.id
            del example.guideline_ids
            self.examples.append(example)

    def _process_id(self, id: str) -> str:
        """
        Processes the ID to convert the Search-compatible values with web-compatible ones.
        """
        return id.replace("=html=", ".html#")

    @staticmethod
    def _render_examples(heading: str, examples: list) -> str:
        """
        Renders one examples section, or an empty string when there are no examples.
        """
        if not examples:
            return ""
        parts = [f"### {heading} Examples\n\n"]
        for example in examples:
            # NOTE(review): the fence language is always "python", whatever
            # the language of the guideline is.
            parts.append(f"```python\n{example.content}\n```\n\n")
            parts.append(f"{example.explanation}\n\n")
        return "".join(parts)

    def to_markdown(self) -> str:
        """
        Converts the context item to a markdown string.

        The output is a heading with the title and ID link, the content, and
        then the good examples followed by the bad ones. Any example whose
        ``example_type`` is not ``"good"`` is listed as a bad example.
        """
        markdown = f"## {self.title} [id]({self.id})\n\n{self.content}\n\n"
        good_examples = [ex for ex in self.examples if ex.example_type == "good"]
        bad_examples = [ex for ex in self.examples if ex.example_type != "good"]
        markdown += self._render_examples("GOOD", good_examples)
        markdown += self._render_examples("BAD", bad_examples)
        return markdown


class SearchManager:
    """
    Searches guidelines and examples for one language and resolves the results
    into a Context.
    """

    def __init__(self, *, language: str, include_general_guidelines: bool = False):
        """
        Builds the search filter and loads the static guidelines from disk.

        :param language: Language whose guidelines and examples are searched.
        :param include_general_guidelines: When true, entries with an empty or
            null ``lang`` are matched as well.
        """
        self.language = language
        self.filter_expression = f"lang eq '{language}'"
        if include_general_guidelines:
            self.filter_expression += " or lang eq '' or lang eq null"
        self.static_guidelines = self._retrieve_static_guidelines(
            language, include_general_guidelines=include_general_guidelines
        )
        self._static_guidelines_map = {x["id"]: x for x in self.static_guidelines}

    def _ensure_env_vars(self, vars: List[str]):
        """
        Ensures that the given environment variables are set.

        :raises ValueError: If one or more of the variables are not set.
        """
        missing = [var for var in vars if os.getenv(var) is None]
        if missing:
            raise ValueError(f"Environment variables not set: {', '.join(missing)}")

    @staticmethod
    def _load_guideline_files(folder: str) -> List[object]:
        """
        Loads every file in the folder as JSON and concatenates the items.
        """
        items = []
        for filename in os.listdir(folder):
            # Read as UTF-8 explicitly rather than with the platform default.
            with open(os.path.join(folder, filename), "r", encoding="utf-8") as f:
                items.extend(json.load(f))
        return items

    def _retrieve_static_guidelines(self, language, include_general_guidelines: bool = False) -> List[object]:
        """
        Retrieves the guidelines for the given language, optional with general guidelines.
        This method retrieves guidelines statically from the file system. It does not
        query any Azure service.

        :returns: The general guidelines (if requested) followed by the language
            guidelines, or an empty list when the language folder does not exist.
        """
        general_guidelines = []
        if include_general_guidelines:
            general_guidelines = self._load_guideline_files(os.path.join(_GUIDELINES_FOLDER, "general"))

        try:
            language_guidelines = self._load_guideline_files(os.path.join(_GUIDELINES_FOLDER, language))
        except FileNotFoundError:
            print(f"WARNING: No guidelines found for language {language}.")
            # NOTE(review): this also drops any general guidelines loaded above -
            # confirm that is intended.
            return []
        return general_guidelines + language_guidelines

    def _semantic_search(self, query: str, *, index_name: str, semantic_configuration_name: str) -> SearchResult:
        """
        Runs the semantic + vector query shared by the guideline and example searches.

        :param query: Text used both as search text and as the vector query text.
        :param index_name: Name of the search index to query.
        :param semantic_configuration_name: Semantic configuration of that index.
        :raises ValueError: If ``AZURE_SEARCH_NAME`` is not set.
        """
        self._ensure_env_vars(["AZURE_SEARCH_NAME"])
        client = SearchClient(
            endpoint=SEARCH_ENDPOINT,
            index_name=index_name,
            credential=CREDENTIAL,
        )
        result = client.search(
            search_text=query,
            filter=self.filter_expression,
            semantic_configuration_name=semantic_configuration_name,
            semantic_error_mode=SemanticErrorMode.FAIL,
            query_type=QueryType.SEMANTIC,
            query_caption=QueryCaptionType.EXTRACTIVE,
            query_answer=QueryAnswerType.EXTRACTIVE,
            top=10,
            vector_queries=[VectorizableTextQuery(text=query, fields="text_vector")],
        )
        return SearchResult(result)

    def search_guidelines(self, query: str) -> SearchResult:
        """
        Searches the guidelines index for the given query and returns the results as a SearchResult object.
        """
        return self._semantic_search(
            query,
            index_name="guidelines-index",
            semantic_configuration_name="archagent-semantic-search-guidelines",
        )

    def search_examples(self, query: str) -> SearchResult:
        """
        Searches the examples index for the given query and returns the results as a SearchResult object.
        """
        return self._semantic_search(
            query,
            index_name="examples-index",
            semantic_configuration_name="archagent-semantic-search-examples",
        )

    def guidelines_for_ids(self, ids: List[str]) -> List[object]:
        """
        Retrieves the guidelines for the given IDs. This method retrieves guidelines
        statically from the file system. It does not query any Azure service.

        Duplicate IDs are collapsed, so the order of the result is not the order
        of ``ids``. An ID without a static guideline contributes ``None``.
        """
        return [self._static_guidelines_map.get(guideline_id) for guideline_id in set(ids)]

    def build_context(self, guideline_results: List[SearchResult], example_results: List[SearchResult]) -> Context:
        """
        Resolves search results into a Context by following the links between
        guidelines and examples stored in CosmosDB.

        Starting from the IDs of the search hits, guidelines are fetched in
        batches; their ``related_guidelines`` and ``related_examples`` are
        followed, and so are the ``guideline_ids`` of every fetched example,
        until no new IDs turn up.

        :param guideline_results: Iterable of objects with an ``id`` attribute.
            NOTE(review): the annotation says List[SearchResult], but the code
            reads ``.id`` from each element, which matches SearchItem - confirm.
        :param example_results: Same, for the example search.
        :raises ValueError: If ``AZURE_COSMOS_ACC_NAME`` or ``AZURE_COSMOS_DB_NAME``
            is not set.
        """
        self._ensure_env_vars(["AZURE_COSMOS_ACC_NAME", "AZURE_COSMOS_DB_NAME"])
        client = CosmosClient(COSMOS_ENDPOINT, credential=CREDENTIAL)
        database = client.get_database_client(COSMOS_DB_NAME)
        guidelines_container = database.get_container_client("guidelines")
        examples_container = database.get_container_client("examples")

        # initial ids from the search queries
        starting_example_ids = list({x.id for x in example_results})
        starting_guideline_ids = list({x.id for x in guideline_results})

        # track the IDs that were already looked up to avoid loops and repeated queries
        requested_guideline_ids = set()
        known_example_ids = set(starting_example_ids)

        # track the final results; examples stay None until they are resolved
        final_guidelines = {}
        final_examples = {ex: None for ex in starting_example_ids}

        # queue for BFS traversal, plus the examples still waiting to be fetched
        queue = deque(starting_guideline_ids)
        pending_example_ids = list(starting_example_ids)
        batch_size = 50

        def batch_query(container, id_list: List[str]) -> List[object]:
            """
            Helper function to batch query the container. An empty ID list
            issues no query and returns an empty list.
            """
            results = []
            for start in range(0, len(id_list), batch_size):
                batch = id_list[start : start + batch_size]
                placeholders = ",".join(f"@id{i}" for i in range(len(batch)))
                query = f"SELECT * FROM c WHERE c.id IN ({placeholders})"
                parameters = [{"name": f"@id{i}", "value": value} for i, value in enumerate(batch)]
                results.extend(
                    container.query_items(
                        query=query,
                        parameters=parameters,
                        enable_cross_partition_query=True,
                    )
                )
            return results

        # Keep going while there is anything left to fetch. Examples are part of
        # the condition so that they are resolved even without guideline hits.
        while queue or pending_example_ids:
            # Pop first, then filter on the popped ID itself.
            popped = [queue.popleft() for _ in range(min(batch_size, len(queue)))]
            batch_ids = list({gid for gid in popped if gid not in requested_guideline_ids})
            requested_guideline_ids.update(batch_ids)

            for guideline in batch_query(guidelines_container, batch_ids):
                gid = guideline["id"]
                if gid in final_guidelines:
                    continue
                final_guidelines[gid] = guideline

                # queue up related guidelines
                for rel in guideline.get("related_guidelines") or []:
                    if rel not in requested_guideline_ids:
                        queue.append(rel)

                # now do the same for examples
                for ex in guideline.get("related_examples") or []:
                    try:
                        if ex not in known_example_ids:
                            known_example_ids.add(ex)
                            final_examples[ex] = None
                            pending_example_ids.append(ex)
                    except TypeError:
                        # FIXME: This shouldn't happen once the data integrity is cleaned up
                        print(f"WARNING: Examples for guideline {gid} is not a string! Skipping.")
                        continue

            # now resolve all examples discovered so far; each ID is looked up once
            example_ids_to_lookup, pending_example_ids = pending_example_ids, []
            for ex in batch_query(examples_container, example_ids_to_lookup):
                final_examples[ex["id"]] = ex
                # queue up more related guidelines from the example
                for gid in ex.get("guideline_ids") or []:
                    if gid not in requested_guideline_ids:
                        queue.append(gid)

        # flatten the results to just the values; unresolved examples are dropped
        resolved_guidelines = [Guideline(**v) for v in final_guidelines.values() if v is not None]
        resolved_examples = [Example(**v) for v in final_examples.values() if v is not None]
        return Context(guidelines=resolved_guidelines, examples=resolved_examples)