code/embedding-function/utilities/helpers/config/config_helper.py (242 lines of code) (raw):

import os import json import logging import functools from ..azure_blob_storage_client import AzureBlobStorageClient from ...document_chunking.chunking_strategy import ChunkingStrategy, ChunkingSettings from ...document_loading import LoadingSettings, LoadingStrategy from .embedding_config import EmbeddingConfig from ..env_helper import EnvHelper from .assistant_strategy import AssistantStrategy from .conversation_flow import ConversationFlow CONFIG_CONTAINER_NAME = "config" CONFIG_FILE_NAME = "active.json" ADVANCED_IMAGE_PROCESSING_FILE_TYPES = ["jpeg", "jpg", "png", "tiff", "bmp"] logger = logging.getLogger(__name__) class Config: def __init__(self, config: dict): self.prompts = Prompts(config["prompts"]) self.messages = Messages(config["messages"]) self.example = Example(config["example"]) self.logging = Logging(config["logging"]) self.document_processors = [ EmbeddingConfig( document_type=c["document_type"], chunking=ChunkingSettings(c["chunking"]), loading=LoadingSettings(c["loading"]), use_advanced_image_processing=c.get( "use_advanced_image_processing", False ), ) for c in config["document_processors"] ] self.env_helper = EnvHelper() self.integrated_vectorization_config = ( IntegratedVectorizationConfig(config["integrated_vectorization_config"]) if self.env_helper.AZURE_SEARCH_USE_INTEGRATED_VECTORIZATION else None ) def get_available_document_types(self) -> list[str]: document_types = { "txt", "pdf", "url", "html", "htm", "md", "jpeg", "jpg", "png", "docx", } if self.env_helper.USE_ADVANCED_IMAGE_PROCESSING: document_types.update(ADVANCED_IMAGE_PROCESSING_FILE_TYPES) return sorted(document_types) def get_advanced_image_processing_image_types(self): return ADVANCED_IMAGE_PROCESSING_FILE_TYPES def get_available_chunking_strategies(self): return [c.value for c in ChunkingStrategy] def get_available_loading_strategies(self): return [c.value for c in LoadingStrategy] def get_available_ai_assistant_types(self): return [c.value for c in AssistantStrategy] def get_available_conversational_flows(self): return [c.value for c in ConversationFlow] # TODO: Change to AnsweringChain or something, Prompts is not a good name class Prompts: def __init__(self, prompts: dict): self.condense_question_prompt = prompts["condense_question_prompt"] self.answering_system_prompt = prompts["answering_system_prompt"] self.answering_user_prompt = prompts["answering_user_prompt"] self.post_answering_prompt = prompts["post_answering_prompt"] self.use_on_your_data_format = prompts["use_on_your_data_format"] self.enable_post_answering_prompt = prompts["enable_post_answering_prompt"] self.enable_content_safety = prompts["enable_content_safety"] self.ai_assistant_type = prompts["ai_assistant_type"] self.conversational_flow = prompts["conversational_flow"] class Example: def __init__(self, example: dict): self.documents = example["documents"] self.user_question = example["user_question"] self.answer = example["answer"] class Messages: def __init__(self, messages: dict): self.post_answering_filter = messages["post_answering_filter"] class Logging: def __init__(self, logging: dict): self.log_user_interactions = ( str(logging["log_user_interactions"]).lower() == "true" ) self.log_tokens = str(logging["log_tokens"]).lower() == "true" class IntegratedVectorizationConfig: def __init__(self, integrated_vectorization_config: dict): self.max_page_length = integrated_vectorization_config["max_page_length"] self.page_overlap_length = integrated_vectorization_config[ "page_overlap_length" ] class ConfigHelper: _default_config = None @staticmethod def _set_new_config_properties(config: dict, default_config: dict): """ Function used to set newer properties that will not be present in older configs. The function mutates the config object. """ if config["prompts"].get("answering_system_prompt") is None: config["prompts"]["answering_system_prompt"] = default_config["prompts"][ "answering_system_prompt" ] prompt_modified = ( config["prompts"].get("answering_prompt") != default_config["prompts"]["answering_prompt"] ) if config["prompts"].get("answering_user_prompt") is None: if prompt_modified: config["prompts"]["answering_user_prompt"] = config["prompts"].get( "answering_prompt" ) else: config["prompts"]["answering_user_prompt"] = default_config["prompts"][ "answering_user_prompt" ] if config["prompts"].get("use_on_your_data_format") is None: config["prompts"]["use_on_your_data_format"] = not prompt_modified if config.get("example") is None: config["example"] = default_config["example"] if config["prompts"].get("ai_assistant_type") is None: config["prompts"]["ai_assistant_type"] = default_config["prompts"][ "ai_assistant_type" ] if config.get("integrated_vectorization_config") is None: config["integrated_vectorization_config"] = default_config[ "integrated_vectorization_config" ] if config["prompts"].get("conversational_flow") is None: config["prompts"]["conversational_flow"] = default_config["prompts"][ "conversational_flow" ] if config.get("enable_chat_history") is None: config["enable_chat_history"] = default_config["enable_chat_history"] @staticmethod @functools.cache def get_active_config_or_default(): logger.info("Method get_active_config_or_default started") env_helper = EnvHelper() config = ConfigHelper.get_default_config() if env_helper.LOAD_CONFIG_FROM_BLOB_STORAGE: logger.info("Loading configuration from Blob Storage") blob_client = AzureBlobStorageClient(container_name=CONFIG_CONTAINER_NAME) if blob_client.file_exists(CONFIG_FILE_NAME): logger.info("Configuration file found in Blob Storage") default_config = config config_file = blob_client.download_file(CONFIG_FILE_NAME) config = json.loads(config_file) ConfigHelper._set_new_config_properties(config, default_config) else: logger.info( "Configuration file not found in Blob Storage, using default configuration" ) logger.info("Method get_active_config_or_default ended") return Config(config) @staticmethod @functools.cache def get_default_assistant_prompt(): config = ConfigHelper.get_default_config() return config["prompts"]["answering_user_prompt"] @staticmethod def save_config_as_active(config): ConfigHelper.validate_config(config) blob_client = AzureBlobStorageClient(container_name=CONFIG_CONTAINER_NAME) blob_client = blob_client.upload_file( json.dumps(config, indent=2), CONFIG_FILE_NAME, content_type="application/json", ) ConfigHelper.get_active_config_or_default.cache_clear() @staticmethod def validate_config(config: dict): for document_processor in config.get("document_processors"): document_type = document_processor.get("document_type") unsupported_advanced_image_processing_file_type = ( document_type not in ADVANCED_IMAGE_PROCESSING_FILE_TYPES ) if ( document_processor.get("use_advanced_image_processing") and unsupported_advanced_image_processing_file_type ): raise Exception( f"Advanced image processing has not been enabled for document type {document_type}, as only {ADVANCED_IMAGE_PROCESSING_FILE_TYPES} file types are supported." ) @staticmethod def get_default_config(): if ConfigHelper._default_config is None: env_helper = EnvHelper() config_file_path = os.path.join(os.path.dirname(__file__), "default.json") with open(config_file_path, encoding="utf-8") as f: logger.info("Loading default config from %s", config_file_path) ConfigHelper._default_config = json.load(f) if env_helper.USE_ADVANCED_IMAGE_PROCESSING: ConfigHelper._append_advanced_image_processors() return ConfigHelper._default_config @staticmethod @functools.cache def get_default_contract_assistant(): contract_file_path = os.path.join( os.path.dirname(__file__), "default_contract_assistant_prompt.txt" ) contract_assistant = "" with open(contract_file_path, encoding="utf-8") as f: contract_assistant = f.readlines() return "".join([str(elem) for elem in contract_assistant]) @staticmethod @functools.cache def get_default_employee_assistant(): employee_file_path = os.path.join( os.path.dirname(__file__), "default_employee_assistant_prompt.txt" ) employee_assistant = "" with open(employee_file_path, encoding="utf-8") as f: employee_assistant = f.readlines() return "".join([str(elem) for elem in employee_assistant]) @staticmethod def clear_config(): ConfigHelper._default_config = None ConfigHelper.get_active_config_or_default.cache_clear() @staticmethod def _append_advanced_image_processors(): image_file_types = ["jpeg", "jpg", "png", "tiff", "bmp"] ConfigHelper._remove_processors_for_file_types(image_file_types) ConfigHelper._default_config["document_processors"].extend( [ {"document_type": file_type, "use_advanced_image_processing": True} for file_type in image_file_types ] ) @staticmethod def _remove_processors_for_file_types(file_types: list[str]): document_processors = ConfigHelper._default_config["document_processors"] document_processors = [ document_processor for document_processor in document_processors if document_processor["document_type"] not in file_types ] ConfigHelper._default_config["document_processors"] = document_processors @staticmethod def delete_config(): blob_client = AzureBlobStorageClient(container_name=CONFIG_CONTAINER_NAME) blob_client.delete_file(CONFIG_FILE_NAME)