# code/embedding-function/utilities/helpers/llm_helper.py
import logging
from openai import AzureOpenAI
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential
from .env_helper import EnvHelper

logger = logging.getLogger(__name__)


class LLMHelper:
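    """Helper for building Azure OpenAI chat/embedding clients and an Azure ML client from environment settings."""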
def __init__(self):
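        """Read configuration from EnvHelper and create the underlying AzureOpenAI SDK client."""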
logger.info("Initializing LLMHelper")
self.env_helper: EnvHelper = EnvHelper()
self.auth_type_keys = self.env_helper.is_auth_type_keys()
self.token_provider = self.env_helper.AZURE_TOKEN_PROVIDER
        logger.info("Using key-based authentication: %s", self.auth_type_keys)
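        # Key-based auth uses the API key directly; otherwise an Azure AD token provider is used (keyless / RBAC).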
if self.auth_type_keys:
self.openai_client = AzureOpenAI(
azure_endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
api_version=self.env_helper.AZURE_OPENAI_API_VERSION,
api_key=self.env_helper.OPENAI_API_KEY,
)
else:
self.openai_client = AzureOpenAI(
azure_endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
api_version=self.env_helper.AZURE_OPENAI_API_VERSION,
azure_ad_token_provider=self.token_provider,
)
self.llm_model = self.env_helper.AZURE_OPENAI_MODEL
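        # AZURE_OPENAI_MAX_TOKENS may be unset; fall back to the service default (None) in that case.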
self.llm_max_tokens = (
int(self.env_helper.AZURE_OPENAI_MAX_TOKENS)
if self.env_helper.AZURE_OPENAI_MAX_TOKENS != ""
else None
)
self.embedding_model = self.env_helper.AZURE_OPENAI_EMBEDDING_MODEL
        logger.info("Azure OpenAI client initialized; embedding model: %s", self.embedding_model)
logger.info("Initializing LLMHelper completed")

    def get_llm(self):
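        """Return an AzureChatOpenAI chat model for the configured deployment (temperature 0)."""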
if self.auth_type_keys:
return AzureChatOpenAI(
deployment_name=self.llm_model,
temperature=0,
max_tokens=self.llm_max_tokens,
                openai_api_version=self.env_helper.AZURE_OPENAI_API_VERSION,
azure_endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
api_key=self.env_helper.OPENAI_API_KEY,
)
else:
return AzureChatOpenAI(
deployment_name=self.llm_model,
temperature=0,
max_tokens=self.llm_max_tokens,
                openai_api_version=self.env_helper.AZURE_OPENAI_API_VERSION,
azure_endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
azure_ad_token_provider=self.token_provider,
)

    # TODO: This needs to have a custom callback to stream back to the UI
    def get_streaming_llm(self):
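        """Return an AzureChatOpenAI chat model with streaming enabled (stdout callback for now)."""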
if self.auth_type_keys:
return AzureChatOpenAI(
azure_endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
api_key=self.env_helper.OPENAI_API_KEY,
streaming=True,
                callbacks=[StreamingStdOutCallbackHandler()],
deployment_name=self.llm_model,
temperature=0,
max_tokens=self.llm_max_tokens,
                openai_api_version=self.env_helper.AZURE_OPENAI_API_VERSION,
)
else:
return AzureChatOpenAI(
azure_endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
                streaming=True,
                callbacks=[StreamingStdOutCallbackHandler()],
deployment_name=self.llm_model,
temperature=0,
max_tokens=self.llm_max_tokens,
                openai_api_version=self.env_helper.AZURE_OPENAI_API_VERSION,
azure_ad_token_provider=self.token_provider,
)

    def get_embedding_model(self):
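        """Return an AzureOpenAIEmbeddings model for the configured embedding deployment."""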
if self.auth_type_keys:
return AzureOpenAIEmbeddings(
azure_endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
api_key=self.env_helper.OPENAI_API_KEY,
azure_deployment=self.embedding_model,
chunk_size=1,
)
else:
return AzureOpenAIEmbeddings(
azure_endpoint=self.env_helper.AZURE_OPENAI_ENDPOINT,
azure_deployment=self.embedding_model,
chunk_size=1,
azure_ad_token_provider=self.token_provider,
)

    def generate_embeddings(self, input: str | list[int]) -> list[float]:
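        """Embed a single text (or pre-tokenized input) and return its embedding vector."""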
return (
self.openai_client.embeddings.create(
input=[input], model=self.embedding_model
)
.data[0]
.embedding
)

    def get_chat_completion_with_functions(
self, messages: list[dict], functions: list[dict], function_call: str = "auto"
):
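        """Call the chat completions API with function-calling definitions and return the raw response."""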
return self.openai_client.chat.completions.create(
model=self.llm_model,
messages=messages,
functions=functions,
function_call=function_call,
)

    def get_chat_completion(
self, messages: list[dict], model: str | None = None, **kwargs
):
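        """Call the chat completions API, defaulting to the configured model; extra kwargs are passed through."""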
return self.openai_client.chat.completions.create(
model=model or self.llm_model,
messages=messages,
max_tokens=self.llm_max_tokens,
**kwargs
)

    def get_ml_client(self):
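        """Lazily create and cache an Azure ML MLClient authenticated with DefaultAzureCredential."""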
if not hasattr(self, "_ml_client"):
self._ml_client = MLClient(
DefaultAzureCredential(),
self.env_helper.AZURE_SUBSCRIPTION_ID,
self.env_helper.AZURE_RESOURCE_GROUP,
self.env_helper.AZURE_ML_WORKSPACE_NAME,
)
return self._ml_client
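

# Minimal usage sketch (assumes the environment variables read by EnvHelper are set):
#   helper = LLMHelper()
#   vector = helper.generate_embeddings("hello world")
#   response = helper.get_chat_completion([{"role": "user", "content": "Hello"}])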