# get_llm_pool()
# Source: 3_optimization-design-ptn/03_prompt-optimization/promptwizard/glue/common/llm/llm_mgr.py

    def get_llm_pool(llm_config: LLMConfig) -> Dict[str, LLM]:
        """
        Create a dictionary of LLMs. Key is the unique id of the LLM, value is an object
        through which methods associated with that LLM service can be called.

        :param llm_config: Object having all settings & preferences for all LLMs to be used in our system
        :return: Dict key=unique_model_id of LLM, value=Object of class llama_index.core.llms.LLM
        which can be used as handle to that LLM
        """
        llm_pool: Dict[str, LLM] = {}
        az_llm_config = llm_config.azure_open_ai

        if az_llm_config:
            # Install provider libraries lazily so configurations that don't use
            # Azure OpenAI never pay the dependency cost.
            install_lib_if_missing(InstallLibs.LLAMA_LLM_AZ_OAI)
            install_lib_if_missing(InstallLibs.LLAMA_EMB_AZ_OAI)
            install_lib_if_missing(InstallLibs.LLAMA_MM_LLM_AZ_OAI)
            install_lib_if_missing(InstallLibs.TIKTOKEN)

            import tiktoken

            # NOTE: the chat/completion branch deliberately uses the raw `openai`
            # SDK client, not llama_index's AzureOpenAI wrapper.
            from openai import AzureOpenAI
            from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
            from llama_index.multi_modal_llms.azure_openai import AzureOpenAIMultiModal
            from azure.identity import get_bearer_token_provider, AzureCliCredential

            # Authentication always goes through Azure AD via the local Azure CLI
            # credential (the former `use_azure_ad` config gate was removed here,
            # though the embedding/multi-modal wrappers below still receive it).
            az_token_provider = get_bearer_token_provider(
                AzureCliCredential(), "https://cognitiveservices.azure.com/.default"
            )

            for azure_oai_model in az_llm_config.azure_oai_models:
                callback_mgr = None
                if azure_oai_model.track_tokens:
                    # Count the number of tokens used in LLM calls via
                    # llama_index's callback machinery.
                    token_counter = TokenCountingHandler(
                        tokenizer=tiktoken.encoding_for_model(
                            azure_oai_model.model_name_in_azure
                        ).encode
                    )
                    callback_mgr = CallbackManager([token_counter])
                    token_counter.reset_counts()

                model_type = azure_oai_model.model_type
                if model_type in (LLMOutputTypes.CHAT, LLMOutputTypes.COMPLETION):
                    # Raw openai client: model/deployment are chosen per request,
                    # and it exposes no callback_manager hook, so callback_mgr
                    # cannot be attached on this branch.
                    llm_pool[azure_oai_model.unique_model_id] = AzureOpenAI(
                        azure_ad_token_provider=az_token_provider,
                        api_key=az_llm_config.api_key,
                        azure_endpoint=az_llm_config.azure_endpoint,
                        api_version=az_llm_config.api_version,
                    )
                elif model_type == LLMOutputTypes.EMBEDDINGS:
                    llm_pool[azure_oai_model.unique_model_id] = AzureOpenAIEmbedding(
                        use_azure_ad=az_llm_config.use_azure_ad,
                        azure_ad_token_provider=az_token_provider,
                        model=azure_oai_model.model_name_in_azure,
                        deployment_name=azure_oai_model.deployment_name_in_azure,
                        api_key=az_llm_config.api_key,
                        azure_endpoint=az_llm_config.azure_endpoint,
                        api_version=az_llm_config.api_version,
                        callback_manager=callback_mgr,
                    )
                elif model_type == LLMOutputTypes.MULTI_MODAL:
                    llm_pool[azure_oai_model.unique_model_id] = AzureOpenAIMultiModal(
                        use_azure_ad=az_llm_config.use_azure_ad,
                        azure_ad_token_provider=az_token_provider,
                        model=azure_oai_model.model_name_in_azure,
                        deployment_name=azure_oai_model.deployment_name_in_azure,
                        api_key=az_llm_config.api_key,
                        azure_endpoint=az_llm_config.azure_endpoint,
                        api_version=az_llm_config.api_version,
                        # BUGFIX: callback_mgr was previously computed but never
                        # passed here, so track_tokens was silently ignored for
                        # multi-modal models (unlike the embeddings branch).
                        callback_manager=callback_mgr,
                        max_new_tokens=4096,
                    )

        if llm_config.custom_models:
            for custom_model in llm_config.custom_models:
                # Dynamically load the user-supplied LLM class from its .py file.
                # TODO(review): load failures used to be wrapped in
                # GlueLLMException; that handling is currently disabled, so a bad
                # path/class name raises the underlying error directly.
                custom_llm_class = str_to_class(
                    custom_model.class_name, None, custom_model.path_to_py_file
                )

                callback_mgr = None
                if custom_model.track_tokens:
                    # Count the number of tokens used in LLM calls, using the
                    # tokenizer the custom class advertises.
                    token_counter = TokenCountingHandler(
                        tokenizer=custom_llm_class.get_tokenizer()
                    )
                    callback_mgr = CallbackManager([token_counter])
                    token_counter.reset_counts()
                llm_pool[custom_model.unique_model_id] = custom_llm_class(
                    callback_manager=callback_mgr
                )
        return llm_pool