code/embedding-function/utilities/helpers/azure_search_helper.py (265 lines of code) (raw):
import logging
from typing import Union
from langchain_community.vectorstores import AzureSearch
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
ExhaustiveKnnAlgorithmConfiguration,
ExhaustiveKnnParameters,
HnswAlgorithmConfiguration,
HnswParameters,
SearchableField,
SearchField,
SearchFieldDataType,
SearchIndex,
SemanticConfiguration,
SemanticField,
SemanticPrioritizedFields,
SemanticSearch,
SimpleField,
VectorSearch,
VectorSearchAlgorithmKind,
VectorSearchAlgorithmMetric,
VectorSearchProfile,
)
from ..helpers.azure_computer_vision_client import AzureComputerVisionClient
from .llm_helper import LLMHelper
from .env_helper import EnvHelper
# Logger for this module; its name is the module's import path.
logger = logging.getLogger(name=__name__)
class AzureSearchHelper:
    """Build Azure AI Search clients and provision the indexes this app uses.

    The search credential is chosen from ``EnvHelper``: an
    ``AzureKeyCredential`` built from ``AZURE_SEARCH_KEY`` when key auth is
    configured, otherwise a ``DefaultAzureCredential``.
    """

    # Vector dimensions are cached on the class rather than on the instance,
    # so the probe request to the embedding / vision service is made at most
    # once per process, however many helpers are constructed.
    _search_dimension: int | None = None
    _image_search_dimension: int | None = None

    # Names that tie the vector fields to the vector-search configuration.
    # They are referenced in several places below, so keep them in one spot.
    _HNSW_ALGORITHM_NAME = "default"
    _HNSW_PROFILE_NAME = "myHnswProfile"
    _EXHAUSTIVE_KNN_ALGORITHM_NAME = "default_exhaustive_knn"
    _EXHAUSTIVE_KNN_PROFILE_NAME = "myExhaustiveKnnProfile"

    def __init__(self) -> None:
        self.llm_helper = LLMHelper()
        self.env_helper = EnvHelper()
        # A single credential object is shared by the document client and the
        # index-management client.
        search_credential = self._search_credential()
        self.search_client = self._create_search_client(search_credential)
        self.search_index_client = self._create_search_index_client(
            search_credential
        )
        self.azure_computer_vision_client = AzureComputerVisionClient(self.env_helper)

    def _search_credential(
        self,
    ) -> Union[AzureKeyCredential, DefaultAzureCredential]:
        """Return the credential used to talk to the search service.

        Key auth uses ``AZURE_SEARCH_KEY``; any other auth type falls back to
        ``DefaultAzureCredential``.
        """
        if self.env_helper.is_auth_type_keys():
            return AzureKeyCredential(self.env_helper.AZURE_SEARCH_KEY)
        return DefaultAzureCredential()

    def _create_search_client(
        self, search_credential: Union[AzureKeyCredential, DefaultAzureCredential]
    ) -> SearchClient:
        """Return a document client bound to ``AZURE_SEARCH_INDEX``."""
        return SearchClient(
            endpoint=self.env_helper.AZURE_SEARCH_SERVICE,
            index_name=self.env_helper.AZURE_SEARCH_INDEX,
            credential=search_credential,
        )

    def _create_search_index_client(
        self, search_credential: Union[AzureKeyCredential, DefaultAzureCredential]
    ) -> SearchIndexClient:
        """Return the management client used to list and create indexes."""
        return SearchIndexClient(
            endpoint=self.env_helper.AZURE_SEARCH_SERVICE, credential=search_credential
        )

    def get_search_client(self) -> SearchClient:
        """Ensure the document index exists, then return its ``SearchClient``."""
        self.create_index()
        return self.search_client

    @property
    def search_dimensions(self) -> int:
        """Length of the text-embedding vectors.

        Discovered once by embedding a probe string, then cached on the class.
        """
        if AzureSearchHelper._search_dimension is None:
            AzureSearchHelper._search_dimension = len(
                self.llm_helper.get_embedding_model().embed_query("Text")
            )
        return AzureSearchHelper._search_dimension

    @property
    def image_search_dimensions(self) -> int:
        """Length of the computer-vision vectors.

        Discovered once by vectorizing a probe string, then cached on the class.
        """
        if AzureSearchHelper._image_search_dimension is None:
            AzureSearchHelper._image_search_dimension = len(
                self.azure_computer_vision_client.vectorize_text("Text")
            )
        return AzureSearchHelper._image_search_dimension

    def create_index(self):
        """Create the document index ``AZURE_SEARCH_INDEX`` if it does not exist.

        An existing index is left untouched; its schema is NOT updated.

        The existence check runs before the index definition is built. Building
        the definition reads the vector dimensions, which may call the
        embedding / vision services. This method runs on every
        ``get_search_client()`` call, so that work should only happen when an
        index is actually going to be created.
        """
        index_name = self.env_helper.AZURE_SEARCH_INDEX
        if not self._index_not_exists(index_name):
            return

        index = SearchIndex(
            name=index_name,
            fields=self._build_index_fields(),
            semantic_search=self._build_semantic_search(),
            vector_search=self._build_vector_search(),
        )
        logger.info("Creating index %s", index_name)
        # NOTE(review): check-then-create is not atomic. If another process
        # creates the index in between, this call presumably raises the SDK's
        # "already exists" error. It is not handled here.
        self.search_index_client.create_index(index)

    def _build_index_fields(self) -> list[SearchField]:
        """Return the field schema of the document index."""
        env = self.env_helper
        fields = [
            SimpleField(
                name=env.AZURE_SEARCH_FIELDS_ID,
                type=SearchFieldDataType.String,
                key=True,
                filterable=True,
            ),
            SearchableField(
                name=env.AZURE_SEARCH_CONTENT_COLUMN,
                type=SearchFieldDataType.String,
            ),
            SearchField(
                name=env.AZURE_SEARCH_CONTENT_VECTOR_COLUMN,
                type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
                searchable=True,
                vector_search_dimensions=self.search_dimensions,
                vector_search_profile_name=self._HNSW_PROFILE_NAME,
            ),
            SearchableField(
                name=env.AZURE_SEARCH_FIELDS_METADATA,
                type=SearchFieldDataType.String,
            ),
            SearchableField(
                name=env.AZURE_SEARCH_TITLE_COLUMN,
                type=SearchFieldDataType.String,
                facetable=True,
                filterable=True,
            ),
            SearchableField(
                name=env.AZURE_SEARCH_SOURCE_COLUMN,
                type=SearchFieldDataType.String,
                filterable=True,
            ),
            SimpleField(
                name=env.AZURE_SEARCH_CHUNK_COLUMN,
                type=SearchFieldDataType.Int32,
                filterable=True,
            ),
            SimpleField(
                name=env.AZURE_SEARCH_OFFSET_COLUMN,
                type=SearchFieldDataType.Int32,
                filterable=True,
            ),
            SearchableField(
                name=env.AZURE_SEARCH_SHAREPOINT_FILE_ID_COLUMN,
                type=SearchFieldDataType.String,
                filterable=True,
            ),
        ]
        if env.USE_ADVANCED_IMAGE_PROCESSING:
            logger.info("Adding image_vector field to index")
            fields.append(
                SearchField(
                    name="image_vector",
                    type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
                    searchable=True,
                    vector_search_dimensions=self.image_search_dimensions,
                    vector_search_profile_name=self._HNSW_PROFILE_NAME,
                )
            )
        return fields

    def _build_semantic_search(self) -> SemanticSearch:
        """Return the semantic-ranking config, which prioritizes the content column."""
        return SemanticSearch(
            configurations=[
                SemanticConfiguration(
                    name=self.env_helper.AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG,
                    prioritized_fields=SemanticPrioritizedFields(
                        title_field=None,
                        content_fields=[
                            SemanticField(
                                field_name=self.env_helper.AZURE_SEARCH_CONTENT_COLUMN
                            )
                        ],
                    ),
                )
            ]
        )

    def _build_vector_search(self) -> VectorSearch:
        """Return the vector-search config: cosine HNSW and cosine exhaustive KNN.

        Each algorithm gets its own profile. The vector fields above use the
        HNSW profile.
        """
        return VectorSearch(
            algorithms=[
                HnswAlgorithmConfiguration(
                    name=self._HNSW_ALGORITHM_NAME,
                    parameters=HnswParameters(
                        metric=VectorSearchAlgorithmMetric.COSINE
                    ),
                    kind=VectorSearchAlgorithmKind.HNSW,
                ),
                ExhaustiveKnnAlgorithmConfiguration(
                    name=self._EXHAUSTIVE_KNN_ALGORITHM_NAME,
                    kind=VectorSearchAlgorithmKind.EXHAUSTIVE_KNN,
                    parameters=ExhaustiveKnnParameters(
                        metric=VectorSearchAlgorithmMetric.COSINE
                    ),
                ),
            ],
            profiles=[
                VectorSearchProfile(
                    name=self._HNSW_PROFILE_NAME,
                    algorithm_configuration_name=self._HNSW_ALGORITHM_NAME,
                ),
                VectorSearchProfile(
                    name=self._EXHAUSTIVE_KNN_PROFILE_NAME,
                    algorithm_configuration_name=self._EXHAUSTIVE_KNN_ALGORITHM_NAME,
                ),
            ],
        )

    def _index_not_exists(self, index_name: str) -> bool:
        """Return True when the service has no index named *index_name*."""
        # ``in`` iterates the returned names directly, with no intermediate
        # list, and stops at the first match.
        return index_name not in self.search_index_client.list_index_names()

    def get_conversation_logger(self):
        """Return a LangChain ``AzureSearch`` store for conversation logs.

        The store targets ``AZURE_SEARCH_CONVERSATIONS_LOG_INDEX`` with the
        schema below. It embeds text with the same model as the document
        index. The search key is passed only under key auth; otherwise
        ``None`` is passed. Presumably LangChain then falls back to Azure AD
        credentials — confirm against the installed langchain_community
        version.
        """
        fields = [
            SimpleField(
                name="id",
                type=SearchFieldDataType.String,
                key=True,
                filterable=True,
            ),
            SimpleField(
                name="conversation_id",
                type=SearchFieldDataType.String,
                filterable=True,
                facetable=True,
            ),
            SearchableField(
                name="content",
                type=SearchFieldDataType.String,
            ),
            SearchField(
                name="content_vector",
                type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
                searchable=True,
                vector_search_dimensions=self.search_dimensions,
                # NOTE(review): no vector-search config is defined for this
                # index here. This name presumably has to match the profile
                # LangChain's AzureSearch creates ("myHnswProfile") — verify.
                vector_search_profile_name=self._HNSW_PROFILE_NAME,
            ),
            SearchableField(
                name="metadata",
                type=SearchFieldDataType.String,
            ),
            SimpleField(
                name="type",
                type=SearchFieldDataType.String,
                facetable=True,
                filterable=True,
            ),
            SimpleField(
                name="user_id",
                type=SearchFieldDataType.String,
                filterable=True,
                facetable=True,
            ),
            SimpleField(
                name="sources",
                type=SearchFieldDataType.Collection(SearchFieldDataType.String),
                filterable=True,
                facetable=True,
            ),
            SimpleField(
                name="created_at",
                type=SearchFieldDataType.DateTimeOffset,
                filterable=True,
            ),
            SimpleField(
                name="updated_at",
                type=SearchFieldDataType.DateTimeOffset,
                filterable=True,
            ),
        ]
        return AzureSearch(
            azure_search_endpoint=self.env_helper.AZURE_SEARCH_SERVICE,
            azure_search_key=(
                self.env_helper.AZURE_SEARCH_KEY
                if self.env_helper.is_auth_type_keys()
                else None
            ),
            index_name=self.env_helper.AZURE_SEARCH_CONVERSATIONS_LOG_INDEX,
            embedding_function=self.llm_helper.get_embedding_model().embed_query,
            fields=fields,
            user_agent="langchain chatwithyourdata-sa",
        )