# datastore/providers/azuresearch_datastore.py
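"""
Azure Cognitive Search implementation of the retrieval plugin DataStore.

Example configuration (the values below are illustrative placeholders):

    export AZURESEARCH_SERVICE=<search-service-name>
    export AZURESEARCH_INDEX=<index-name>
    export AZURESEARCH_API_KEY=<api-key>  # optional; DefaultAzureCredential is used when unset
"""
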
import asyncio
import base64
import os
import re
from typing import Dict, List, Optional, Union
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential as DefaultAzureCredentialSync
from azure.identity.aio import DefaultAzureCredential
from azure.search.documents.aio import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import *
from azure.search.documents.models import QueryType, Vector
from loguru import logger
from datastore.datastore import DataStore
from models.models import (
DocumentChunk,
DocumentChunkMetadata,
DocumentChunkWithScore,
DocumentMetadataFilter,
Query,
QueryResult,
QueryWithEmbedding,
)
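
# Configuration from environment variables. The service and index names are
# required; the API key is optional (DefaultAzureCredential is used when it is
# not set).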
AZURESEARCH_SERVICE = os.environ.get("AZURESEARCH_SERVICE")
AZURESEARCH_INDEX = os.environ.get("AZURESEARCH_INDEX")
AZURESEARCH_API_KEY = os.environ.get("AZURESEARCH_API_KEY")
AZURESEARCH_SEMANTIC_CONFIG = os.environ.get("AZURESEARCH_SEMANTIC_CONFIG")
AZURESEARCH_LANGUAGE = os.environ.get("AZURESEARCH_LANGUAGE", "en-us")
AZURESEARCH_DISABLE_HYBRID = os.environ.get("AZURESEARCH_DISABLE_HYBRID")
AZURESEARCH_DIMENSIONS = int(
    os.environ.get("AZURESEARCH_DIMENSIONS", 256)
)  # Default to 256 dimensions, change if using a different embeddings model
assert AZURESEARCH_SERVICE is not None, "AZURESEARCH_SERVICE environment variable is required"
assert AZURESEARCH_INDEX is not None, "AZURESEARCH_INDEX environment variable is required"
# Allow overriding field names for Azure Search
FIELDS_ID = os.environ.get("AZURESEARCH_FIELDS_ID", "id")
FIELDS_TEXT = os.environ.get("AZURESEARCH_FIELDS_TEXT", "text")
FIELDS_EMBEDDING = os.environ.get("AZURESEARCH_FIELDS_EMBEDDING", "embedding")
FIELDS_DOCUMENT_ID = os.environ.get("AZURESEARCH_FIELDS_DOCUMENT_ID", "document_id")
FIELDS_SOURCE = os.environ.get("AZURESEARCH_FIELDS_SOURCE", "source")
FIELDS_SOURCE_ID = os.environ.get("AZURESEARCH_FIELDS_SOURCE_ID", "source_id")
FIELDS_URL = os.environ.get("AZURESEARCH_FIELDS_URL", "url")
FIELDS_CREATED_AT = os.environ.get("AZURESEARCH_FIELDS_CREATED_AT", "created_at")
FIELDS_AUTHOR = os.environ.get("AZURESEARCH_FIELDS_AUTHOR", "author")
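
# Azure Cognitive Search limits indexing batches to 1000 documents per request,
# so uploads and deletes are chunked into batches of at most this size.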
MAX_UPLOAD_BATCH_SIZE = 1000
MAX_DELETE_BATCH_SIZE = 1000


class AzureSearchDataStore(DataStore):
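    """
    DataStore implementation backed by Azure Cognitive Search. Chunks are stored as
    index documents and retrieved with vector search, optionally combined with
    keyword (hybrid) search and semantic ranking.
    """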
def __init__(self):
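        # Create an async SearchClient for data-plane operations and a synchronous
        # SearchIndexClient used once at startup to create the index if it is missing.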
self.client = SearchClient(
endpoint=f"https://{AZURESEARCH_SERVICE}.search.windows.net",
index_name=AZURESEARCH_INDEX,
credential=AzureSearchDataStore._create_credentials(True),
user_agent="retrievalplugin",
)
mgmt_client = SearchIndexClient(
endpoint=f"https://{AZURESEARCH_SERVICE}.search.windows.net",
credential=AzureSearchDataStore._create_credentials(False),
user_agent="retrievalplugin",
)
        if AZURESEARCH_INDEX not in mgmt_client.list_index_names():
self._create_index(mgmt_client)
else:
logger.info(
f"Using existing index {AZURESEARCH_INDEX} in service {AZURESEARCH_SERVICE}"
)

    async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
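        """
        Takes in a dict from document id to its list of chunks and uploads them to
        the index in batches. Returns the list of document ids.
        """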
azdocuments: List[Dict] = []
async def upload():
r = await self.client.upload_documents(documents=azdocuments)
count = sum(1 for rr in r if rr.succeeded)
logger.info(f"Upserted {count} chunks out of {len(azdocuments)}")
if count < len(azdocuments):
raise Exception(f"Failed to upload {len(azdocuments) - count} chunks")
ids = []
for document_id, document_chunks in chunks.items():
ids.append(document_id)
for chunk in document_chunks:
azdocuments.append(
{
# base64-encode the id string to stay within Azure Search's valid characters for keys
FIELDS_ID: base64.urlsafe_b64encode(
bytes(chunk.id, "utf-8")
).decode("ascii"),
FIELDS_TEXT: chunk.text,
FIELDS_EMBEDDING: chunk.embedding,
FIELDS_DOCUMENT_ID: document_id,
FIELDS_SOURCE: chunk.metadata.source,
FIELDS_SOURCE_ID: chunk.metadata.source_id,
FIELDS_URL: chunk.metadata.url,
FIELDS_CREATED_AT: chunk.metadata.created_at,
FIELDS_AUTHOR: chunk.metadata.author,
}
)
if len(azdocuments) >= MAX_UPLOAD_BATCH_SIZE:
await upload()
azdocuments = []
if len(azdocuments) > 0:
await upload()
return ids

    async def delete(
self,
ids: Optional[List[str]] = None,
filter: Optional[DocumentMetadataFilter] = None,
delete_all: Optional[bool] = None,
) -> bool:
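        """
        Removes chunks by id, by metadata filter, or everything in the index when
        delete_all is set. Returns whether the operation succeeded.
        """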
filter = None if delete_all else self._translate_filter(filter)
if delete_all or filter is not None:
deleted = set()
while True:
search_result = await self.client.search(
None,
filter=filter,
top=MAX_DELETE_BATCH_SIZE,
include_total_count=True,
select=FIELDS_ID,
)
if await search_result.get_count() == 0:
break
documents = [
{FIELDS_ID: d[FIELDS_ID]}
async for d in search_result
if d[FIELDS_ID] not in deleted
]
if len(documents) > 0:
logger.info(
f"Deleting {len(documents)} chunks "
+ (
"using a filter"
if filter is not None
else "using delete_all"
)
)
del_result = await self.client.delete_documents(documents=documents)
if not all([rr.succeeded for rr in del_result]):
raise Exception("Failed to delete documents")
deleted.update([d[FIELDS_ID] for d in documents])
else:
                    # All repeats, delay a bit to let the index refresh and try again
                    # (non-blocking sleep so the event loop stays responsive)
                    await asyncio.sleep(0.25)
if ids is not None and len(ids) > 0:
for id in ids:
logger.info(f"Deleting chunks for document id {id}")
await self.delete(filter=DocumentMetadataFilter(document_id=id))
return True

    async def _query(self, queries: List[QueryWithEmbedding]) -> List[QueryResult]:
"""
Takes in a list of queries with embeddings and filters and returns a list of query results with matching document chunks and scores.
"""
return await asyncio.gather(*(self._single_query(query) for query in queries))

    async def _single_query(self, query: QueryWithEmbedding) -> QueryResult:
"""
Takes in a single query and filters and returns a query result with matching document chunks and scores.
"""
filter = (
self._translate_filter(query.filter) if query.filter is not None else None
)
try:
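            # Request extra vector candidates: twice top_k when a metadata filter
            # may drop hits, and twice again when hybrid (text + vector) search is
            # enabled, so enough results remain after filtering and fusion.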
vector_top_k = query.top_k if filter is None else query.top_k * 2
if not AZURESEARCH_DISABLE_HYBRID:
vector_top_k *= 2
            q = query.query if not AZURESEARCH_DISABLE_HYBRID else None
            use_semantic = (
                AZURESEARCH_SEMANTIC_CONFIG is not None
                and not AZURESEARCH_DISABLE_HYBRID
            )
            if use_semantic:
                # Ensure we're feeding a good number of candidates to the L2 reranker
                vector_top_k = max(50, vector_top_k)
            # Build the vector query only after vector_top_k is final, so the
            # reranker adjustment above is actually reflected in the request
            vector_q = Vector(
                value=query.embedding, k=vector_top_k, fields=FIELDS_EMBEDDING
            )
            if use_semantic:
r = await self.client.search(
q,
filter=filter,
top=query.top_k,
vectors=[vector_q],
query_type=QueryType.SEMANTIC,
query_language=AZURESEARCH_LANGUAGE,
semantic_configuration_name=AZURESEARCH_SEMANTIC_CONFIG,
)
else:
r = await self.client.search(
q, filter=filter, top=query.top_k, vectors=[vector_q]
)
results: List[DocumentChunkWithScore] = []
async for hit in r:
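                # A field name configured as "-" is treated as unmapped and is
                # returned as None instead of being read from the document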
f = lambda field: hit.get(field) if field != "-" else None
results.append(
DocumentChunkWithScore(
id=hit[FIELDS_ID],
text=hit[FIELDS_TEXT],
metadata=DocumentChunkMetadata(
document_id=f(FIELDS_DOCUMENT_ID),
source=f(FIELDS_SOURCE) or "file",
source_id=f(FIELDS_SOURCE_ID),
url=f(FIELDS_URL),
created_at=f(FIELDS_CREATED_AT),
author=f(FIELDS_AUTHOR),
),
score=hit["@search.score"],
)
)
return QueryResult(query=query.query, results=results)
except Exception as e:
raise Exception(f"Error querying the index: {e}")

    @staticmethod
def _translate_filter(filter: DocumentMetadataFilter) -> str:
"""
Translates a DocumentMetadataFilter into an Azure Search filter string
"""
if filter is None:
return None
escape = lambda s: s.replace("'", "''")
# regex to validate dates are in OData format
date_re = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z")
filter_list = []
if filter.document_id is not None:
filter_list.append(
f"{FIELDS_DOCUMENT_ID} eq '{escape(filter.document_id)}'"
)
if filter.source is not None:
filter_list.append(f"{FIELDS_SOURCE} eq '{escape(filter.source)}'")
if filter.source_id is not None:
filter_list.append(f"{FIELDS_SOURCE_ID} eq '{escape(filter.source_id)}'")
if filter.author is not None:
filter_list.append(f"{FIELDS_AUTHOR} eq '{escape(filter.author)}'")
if filter.start_date is not None:
if not date_re.match(filter.start_date):
raise ValueError(
f"start_date must be in OData format, got {filter.start_date}"
)
filter_list.append(f"{FIELDS_CREATED_AT} ge {filter.start_date}")
if filter.end_date is not None:
if not date_re.match(filter.end_date):
raise ValueError(
f"end_date must be in OData format, got {filter.end_date}"
)
filter_list.append(f"{FIELDS_CREATED_AT} le {filter.end_date}")
return " and ".join(filter_list) if len(filter_list) > 0 else None

    def _create_index(self, mgmt_client: SearchIndexClient):
"""
Creates an Azure Cognitive Search index, including a semantic search configuration if a name is specified for it
"""
logger.info(
f"Creating index {AZURESEARCH_INDEX} in service {AZURESEARCH_SERVICE}"
+ (
f" with semantic search configuration {AZURESEARCH_SEMANTIC_CONFIG}"
if AZURESEARCH_SEMANTIC_CONFIG is not None
else ""
)
)
mgmt_client.create_index(
SearchIndex(
name=AZURESEARCH_INDEX,
fields=[
SimpleField(
name=FIELDS_ID, type=SearchFieldDataType.String, key=True
),
SearchableField(
name=FIELDS_TEXT,
type=SearchFieldDataType.String,
analyzer_name="standard.lucene",
),
SearchField(
name=FIELDS_EMBEDDING,
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
hidden=False,
searchable=True,
filterable=False,
sortable=False,
facetable=False,
vector_search_dimensions=AZURESEARCH_DIMENSIONS,
vector_search_configuration="default",
),
SimpleField(
name=FIELDS_DOCUMENT_ID,
type=SearchFieldDataType.String,
filterable=True,
sortable=True,
),
SimpleField(
name=FIELDS_SOURCE,
type=SearchFieldDataType.String,
filterable=True,
sortable=True,
),
SimpleField(
name=FIELDS_SOURCE_ID,
type=SearchFieldDataType.String,
filterable=True,
sortable=True,
),
SimpleField(name=FIELDS_URL, type=SearchFieldDataType.String),
SimpleField(
name=FIELDS_CREATED_AT,
type=SearchFieldDataType.DateTimeOffset,
filterable=True,
sortable=True,
),
SimpleField(
name=FIELDS_AUTHOR,
type=SearchFieldDataType.String,
filterable=True,
sortable=True,
),
],
semantic_settings=None
if AZURESEARCH_SEMANTIC_CONFIG is None
else SemanticSettings(
configurations=[
SemanticConfiguration(
name=AZURESEARCH_SEMANTIC_CONFIG,
prioritized_fields=PrioritizedFields(
title_field=None,
prioritized_content_fields=[
SemanticField(field_name=FIELDS_TEXT)
],
),
)
]
),
vector_search=VectorSearch(
algorithm_configurations=[
HnswVectorSearchAlgorithmConfiguration(
name="default",
kind="hnsw",
# Could change to dotproduct for OpenAI's embeddings since they normalize vectors to unit length
hnsw_parameters=HnswParameters(metric="cosine"),
)
]
),
)
)

    @staticmethod
def _create_credentials(
use_async: bool,
) -> Union[AzureKeyCredential, DefaultAzureCredential, DefaultAzureCredentialSync]:
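        """
        Returns an AzureKeyCredential when AZURESEARCH_API_KEY is set, otherwise a
        DefaultAzureCredential (async or sync depending on use_async).
        """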
if AZURESEARCH_API_KEY is None:
logger.info(
"Using DefaultAzureCredential for Azure Search, make sure local identity or managed identity are set up appropriately"
)
credential = (
DefaultAzureCredential() if use_async else DefaultAzureCredentialSync()
)
else:
logger.info("Using an API key to authenticate with Azure Search")
credential = AzureKeyCredential(AZURESEARCH_API_KEY)
return credential