# code/embedding-function/utilities/helpers/embedders/push_embedder.py
import hashlib
import json
import logging
from typing import List
from urllib.parse import urlparse
from ...helpers.llm_helper import LLMHelper
from ...helpers.env_helper import EnvHelper
from ..azure_computer_vision_client import AzureComputerVisionClient
from ..azure_blob_storage_client import AzureBlobStorageClient
from ..config.embedding_config import EmbeddingConfig
from ..config.config_helper import ConfigHelper
from .embedder_base import EmbedderBase
from ..azure_search_helper import AzureSearchHelper
from ..document_loading_helper import DocumentLoading
from ..document_chunking_helper import DocumentChunking
from ...common.source_document import SourceDocument

logger = logging.getLogger(__name__)


class PushEmbedder(EmbedderBase):
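    """Embeds documents and pushes them into the configured Azure Search index.

    Each file is processed according to the embedding configuration registered
    for its extension: text content is loaded, chunked, and embedded, while
    images can optionally be captioned and vectorized via Azure Computer Vision.
    """
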
def __init__(self, blob_client: AzureBlobStorageClient, env_helper: EnvHelper):
logger.info("Initializing PushEmbedder")
self.env_helper = env_helper
self.llm_helper = LLMHelper()
self.azure_search_helper = AzureSearchHelper()
self.azure_computer_vision_client = AzureComputerVisionClient(env_helper)
self.document_loading = DocumentLoading()
self.document_chunking = DocumentChunking()
self.blob_client = blob_client
self.config = ConfigHelper.get_active_config_or_default()
self.embedding_configs = {}
logger.info("Loading document processors")
for processor in self.config.document_processors:
ext = processor.document_type.lower()
self.embedding_configs[ext] = processor
logger.info("Document processors loaded")

    def embed_file(self, source_url: str, file_name: str, sharepoint_file_id: str):
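        """Embed a single file and, for blob-backed files, mark the blob as processed.

        Sources with the special "url" extension have no backing blob, so no
        metadata is upserted for them.
        """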
logger.info(f"Embedding file: {file_name} from URL: {source_url} - sharepoint_file_id: {sharepoint_file_id}")
file_extension = file_name.split(".")[-1].lower()
        embedding_config = self.embedding_configs.get(file_extension)
        if embedding_config is None:
            # Fail fast with a clear message instead of an AttributeError in __embed.
            raise ValueError(f"No document processor configured for file extension: {file_extension}")
self.__embed(
source_url=source_url,
file_extension=file_extension,
embedding_config=embedding_config,
sharepoint_file_id=sharepoint_file_id
)
if file_extension != "url":
logger.info(f"Upserting blob metadata for file: {file_name}")
self.blob_client.upsert_blob_metadata(
file_name, {"embeddings_added": "true"}
)

    def __embed(
        self,
        source_url: str,
        file_extension: str,
        embedding_config: EmbeddingConfig,
        sharepoint_file_id: str,
    ):
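        """Embed one source into search documents and upload them in batches.

        Advanced image processing captions and vectorizes images directly;
        all other content is loaded, chunked, and embedded as text.
        """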
logger.info(f"Processing embedding for file extension: {file_extension}")
documents_to_upload: List[SourceDocument] = []
if (
embedding_config.use_advanced_image_processing
and file_extension
in self.config.get_advanced_image_processing_image_types()
):
logger.info(f"Using advanced image processing for: {source_url}")
caption = self.__generate_image_caption(source_url)
caption_vector = self.llm_helper.generate_embeddings(caption)
image_vector = self.azure_computer_vision_client.vectorize_image(source_url)
documents_to_upload.append(
self.__create_image_document(
source_url, image_vector, caption, caption_vector, sharepoint_file_id
)
)
else:
logger.info(f"Loading documents from source: {source_url}")
documents: List[SourceDocument] = self.document_loading.load(
source_url, embedding_config.loading
)
documents = self.document_chunking.chunk(
documents, embedding_config.chunking
)
for document in documents:
                documents_to_upload.append(
                    self.__convert_to_search_document(document, sharepoint_file_id)
                )
# Upload documents (which are chunks) to search index in batches
if documents_to_upload:
logger.info("Uploading documents in batches")
batch_size = self.env_helper.AZURE_SEARCH_DOC_UPLOAD_BATCH_SIZE
search_client = self.azure_search_helper.get_search_client()
for i in range(0, len(documents_to_upload), batch_size):
batch = documents_to_upload[i : i + batch_size]
response = search_client.upload_documents(batch)
                if not all(r.succeeded for r in response):
logger.error("Failed to upload documents to search index")
raise RuntimeError(f"Upload failed for some documents: {response}")
else:
logger.warning("No documents to upload.")

    def __generate_image_caption(self, source_url: str) -> str:
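        """Ask the vision-capable chat model for a detailed caption of the image."""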
logger.info(f"Generating image caption for URL: {source_url}")
model = self.env_helper.AZURE_OPENAI_VISION_MODEL
        caption_system_message = """You are an assistant that generates rich descriptions of images.
You need to be accurate in the information you extract and detailed in the descriptions you generate.
Do not abbreviate anything and do not shorten sentences. Explain the image completely.
If you are provided with an image of a flow chart, describe the flow chart in detail.
If the image is mostly text, use OCR to extract the text as it is displayed in the image."""
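        # A single user turn carries both the text instruction and the image URL.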
messages = [
{"role": "system", "content": caption_system_message},
{
"role": "user",
"content": [
{
"text": "Describe this image in detail. Limit the response to 500 words.",
"type": "text",
},
{"image_url": {"url": source_url}, "type": "image_url"},
],
},
]
response = self.llm_helper.get_chat_completion(messages, model)
caption = response.choices[0].message.content
logger.info("Caption generation completed")
return caption

    def __convert_to_search_document(self, document: SourceDocument, sharepoint_file_id: str):
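        """Map a chunked SourceDocument onto the search index schema, embedding its content."""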
logger.info(f"Converting document ID {document.id} to search document format")
embedded_content = self.llm_helper.generate_embeddings(document.content)
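        # The metadata field duplicates several top-level fields as a JSON string.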
metadata = {
self.env_helper.AZURE_SEARCH_FIELDS_ID: document.id,
self.env_helper.AZURE_SEARCH_SOURCE_COLUMN: document.source,
self.env_helper.AZURE_SEARCH_TITLE_COLUMN: document.title,
self.env_helper.AZURE_SEARCH_CHUNK_COLUMN: document.chunk,
self.env_helper.AZURE_SEARCH_OFFSET_COLUMN: document.offset,
"page_number": document.page_number,
"chunk_id": document.chunk_id,
self.env_helper.AZURE_SEARCH_SHAREPOINT_FILE_ID_COLUMN: sharepoint_file_id
}
return {
self.env_helper.AZURE_SEARCH_FIELDS_ID: document.id,
self.env_helper.AZURE_SEARCH_CONTENT_COLUMN: document.content,
self.env_helper.AZURE_SEARCH_CONTENT_VECTOR_COLUMN: embedded_content,
self.env_helper.AZURE_SEARCH_FIELDS_METADATA: json.dumps(metadata),
self.env_helper.AZURE_SEARCH_TITLE_COLUMN: document.title,
self.env_helper.AZURE_SEARCH_SOURCE_COLUMN: document.source,
self.env_helper.AZURE_SEARCH_CHUNK_COLUMN: document.chunk,
self.env_helper.AZURE_SEARCH_OFFSET_COLUMN: document.offset,
self.env_helper.AZURE_SEARCH_SHAREPOINT_FILE_ID_COLUMN: sharepoint_file_id
}

    def __generate_document_id(self, source_url: str) -> str:
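        """Derive a deterministic id by SHA-1 hashing the source URL plus a fixed "_1" suffix."""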
hash_key = hashlib.sha1(f"{source_url}_1".encode("utf-8")).hexdigest()
return f"doc_{hash_key}"

    def __create_image_document(
        self,
        source_url: str,
        image_vector: List[float],
        content: str,
        content_vector: List[float],
        sharepoint_file_id: str,
    ):
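        """Build the search index entry for an image, carrying both caption and image vectors."""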
logger.info(f"Creating image document for source URL: {source_url}")
parsed_url = urlparse(source_url)
file_url = parsed_url.scheme + "://" + parsed_url.netloc + parsed_url.path
document_id = self.__generate_document_id(file_url)
filename = parsed_url.path
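        # Blob-hosted files get a placeholder that downstream code can swap for a real SAS token.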
sas_placeholder = (
"_SAS_TOKEN_PLACEHOLDER_"
if parsed_url.netloc
and parsed_url.netloc.endswith(".blob.core.windows.net")
else ""
)
return {
"id": document_id,
"content": content,
"content_vector": content_vector,
"image_vector": image_vector,
"metadata": json.dumps(
{
"id": document_id,
"title": filename,
"source": file_url + sas_placeholder,
}
),
"title": filename,
"source": file_url + sas_placeholder,
"sharepoint_file_id": sharepoint_file_id
}