backend-apis/app/utils/utils_gemini.py:

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Utils module for Vertex AI Gemini
"""

import asyncio
import functools

from google.api_core.exceptions import GoogleAPICallError
from vertexai.generative_models._generative_models import GenerationResponse
from vertexai.preview.generative_models import GenerativeModel

gemini_pro_vision = GenerativeModel("gemini-1.5-pro")
gemini_pro_text = GenerativeModel("gemini-1.5-pro")


def generate_gemini_pro_vision(contents: list) -> GenerationResponse:
    """Generates a (multimodal) response from Gemini.

    Args:
        contents: List of prompt parts (text, images, etc.) sent to the model.

    Returns:
        The model's GenerationResponse.
    """
    response = gemini_pro_vision.generate_content(
        contents=contents
    )

    if isinstance(response, GenerationResponse):
        return response

    # Streamed responses are iterable; return the first chunk.
    response_iterator = iter(response)
    return next(response_iterator)


def generate_gemini_pro_text(
        prompt: str,
        max_output_tokens: int = 2048,
        temperature: float = 0.2,
        top_k: int = 40,
        top_p: float = 0.9,
        candidate_count: int = 1
) -> str:
    """Generates a text response from Gemini.

    Args:
        prompt: Text prompt sent to the model.
        max_output_tokens: Maximum number of tokens in the response.
        temperature: Sampling temperature.
        top_k: Top-k sampling parameter.
        top_p: Top-p (nucleus) sampling parameter.
        candidate_count: Number of response candidates to generate.

    Returns:
        LLM response text.
    """
    response = gemini_pro_text.generate_content(
        contents=[prompt],
        generation_config={
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "candidate_count": candidate_count,
            "max_output_tokens": max_output_tokens,
        }
    )
    return response.text


async def async_predict_text_llm(
        prompt: str,
        max_output_tokens: int = 1024,
        temperature: float = 0.2,
        top_k: int = 40,
        top_p: float = 0.8,
) -> str:
    """Asynchronously generates a text response from Gemini.

    Args:
        prompt: Text prompt sent to the model.
        max_output_tokens: Maximum number of tokens in the response.
        temperature: Sampling temperature.
        top_k: Top-k sampling parameter.
        top_p: Top-p (nucleus) sampling parameter.

    Returns:
        LLM response text with Markdown code fences stripped,
        or an empty string if the call fails.
    """
    loop = asyncio.get_running_loop()
    generated_response = None

    try:
        # generate_content is a blocking call, so run it in the default
        # thread-pool executor to avoid blocking the event loop.
        generated_response = await loop.run_in_executor(
            None,
            functools.partial(
                gemini_pro_text.generate_content,
                contents=prompt,
                generation_config={
                    "temperature": temperature,
                    "top_p": top_p,
                    "top_k": top_k,
                    "max_output_tokens": max_output_tokens,
                },
            ),
        )
    except GoogleAPICallError as e:
        print(e)
        return ""

    if generated_response and generated_response.text:
        generated_response = generated_response.text.replace("```json", "")
        generated_response = generated_response.replace("```JSON", "")
        generated_response = generated_response.replace("```", "")
        return generated_response

    return ""


async def run_predict_text_llm(
        prompts: list,
        temperature: float = 0.2
) -> list:
    """Runs async_predict_text_llm concurrently over a list of prompts.

    Args:
        prompts: List of text prompts sent to the model.
        temperature: Sampling temperature.

    Returns:
        List of LLM response texts, in the same order as the prompts.
    """
    tasks = [
        async_predict_text_llm(prompt=prompt, temperature=temperature)
        for prompt in prompts
    ]
    results = await asyncio.gather(*tasks)
    return results
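
# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original module): a minimal sketch of
# how these helpers might be called, assuming vertexai.init(project=...,
# location=...) has already been run elsewhere in the application, and
# assuming the module is importable as app.utils.utils_gemini (both are
# assumptions, not confirmed by this file).
#
#   import asyncio
#   from app.utils import utils_gemini
#
#   # Single blocking text generation
#   answer = utils_gemini.generate_gemini_pro_text(
#       prompt="Summarise the return policy in one sentence.",
#   )
#
#   # Concurrent generation over several prompts
#   answers = asyncio.run(
#       utils_gemini.run_predict_text_llm(
#           prompts=["Describe product A.", "Describe product B."],
#           temperature=0.2,
#       )
#   )
# ---------------------------------------------------------------------------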