genai-on-vertex-ai/gemini/needle_in_a_haystack/needlehaystack/providers/google.py (53 lines of code) (raw):

import os
import tempfile
from typing import Optional

import pkg_resources
import requests
import sentencepiece
import vertexai
from google.cloud.aiplatform_v1 import HarmCategory
from vertexai.generative_models import GenerativeModel, HarmBlockThreshold

from .model import ModelProvider


class Google(ModelProvider):
    """
    A wrapper class for interacting with Google's Gemini models on Vertex AI, providing methods
    to encode/decode text with a SentencePiece tokenizer, build prompts, and evaluate models.

    Attributes:
        model_name (str): The name of the Google model to use for evaluations and interactions.
        model_kwargs (dict): Generation config passed to every request (a private copy per instance).
        model: A Vertex AI ``GenerativeModel`` instance used for API calls.
        tokenizer: A ``sentencepiece.SentencePieceProcessor`` for encoding and decoding text to
            and from token IDs.
        prompt_structure (str): Prompt template read from ``providers/gemini_prompt.txt``, with
            ``{retrieval_question}`` and ``{context}`` placeholders filled by ``generate_prompt``.
    """
    DEFAULT_MODEL_KWARGS: dict = dict(max_output_tokens=300, temperature=0)
    VOCAB_FILE_URL = "https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model"

    # Where the downloaded vocab file is cached, relative to the current working directory.
    LOCAL_VOCAB_FILE = 'tokenizer.model'
    # (connect, read) timeout in seconds for the vocab download; without it requests can block forever.
    DOWNLOAD_TIMEOUT = (10, 60)
    # Stream the download in 64 KiB pieces (iter_content's default chunk size is a single byte).
    DOWNLOAD_CHUNK_SIZE = 1 << 16

    # Disable content blocking for every category sent to the model. The haystack texts are
    # arbitrary essays; presumably this avoids spurious blocks skewing retrieval results.
    SAFETY_SETTINGS = {
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
    }

    def __init__(self,
                 project_id: str,
                 model_name: str = "gemini-2.0-flash-001",
                 model_kwargs: dict = DEFAULT_MODEL_KWARGS,
                 vocab_file_url: str = VOCAB_FILE_URL,
                 location: str = "us-central1"):
        """
        Initializes the Google model provider with a specific model.

        Args:
            project_id (str): ID of the Google Cloud project to use.
            model_name (str): The name of the Google model to use. Defaults to 'gemini-2.0-flash-001'.
            model_kwargs (dict): Generation config. Defaults to
                {max_output_tokens: 300, temperature: 0}. The dict is copied, so later changes
                to ``self.model_kwargs`` never affect the caller's dict or the class default.
            vocab_file_url (str): URL of the SentencePiece model file that defines the tokenization
                vocabulary. Defaults to the Gemma tokenizer
                https://github.com/google/gemma_pytorch/blob/main/tokenizer/tokenizer.model
                Only used when ``tokenizer.model`` is not already present in the working
                directory; an existing file is reused regardless of this URL.
            location (str): Vertex AI region passed to ``vertexai.init``. Defaults to 'us-central1'.

        Raises:
            requests.HTTPError: If downloading the vocab file returns an error status.
            requests.RequestException: On other download failures, including timeouts.
        """
        self.model_name = model_name
        # Copy so mutating self.model_kwargs can never alter the shared class-level default.
        self.model_kwargs = dict(model_kwargs)

        vertexai.init(project=project_id, location=location)
        self.model = GenerativeModel(self.model_name)

        local_vocab_file = self.LOCAL_VOCAB_FILE
        if not os.path.exists(local_vocab_file):
            self._download_vocab_file(vocab_file_url, local_vocab_file)  # ~4MB download
        self.tokenizer = sentencepiece.SentencePieceProcessor(local_vocab_file)

        resource_path = pkg_resources.resource_filename('needlehaystack', 'providers/gemini_prompt.txt')

        # Generate the prompt structure for the model
        # Replace the following file with the appropriate prompt structure
        with open(resource_path, 'r', encoding='utf-8') as file:
            self.prompt_structure = file.read()

    @classmethod
    def _download_vocab_file(cls, url: str, destination: str) -> None:
        """
        Streams ``url`` to ``destination``, replacing it atomically.

        The bytes are written to a temporary file in the destination's directory and renamed
        into place only after the download completes. Callers decide whether to download by
        checking that the file exists, so an interrupted download must never leave a truncated
        file behind.

        Args:
            url (str): URL to download.
            destination (str): Path of the file to create or replace.

        Raises:
            requests.HTTPError: If the server returns an error status.
            requests.RequestException: On connection errors or timeouts.
        """
        target_dir = os.path.dirname(os.path.abspath(destination))
        with requests.get(url, stream=True, timeout=cls.DOWNLOAD_TIMEOUT) as response:
            response.raise_for_status()
            fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix='.part')
            try:
                with os.fdopen(fd, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=cls.DOWNLOAD_CHUNK_SIZE):
                        f.write(chunk)
                os.replace(tmp_path, destination)
            except BaseException:
                # Clean up the partial file, including on KeyboardInterrupt, then propagate.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                raise

    async def evaluate_model(self, prompt: str) -> str:
        """
        Evaluates a given prompt using the Google model and retrieves the model's response.

        Args:
            prompt (str): The prompt to send to the model.

        Returns:
            str: The text content of the model's response to the prompt.

        Note:
            ``response.text`` may raise (behavior defined by the Vertex AI SDK) if the response
            carries no text, e.g. when no candidate was returned. That error is not caught here.
        """
        response = await self.model.generate_content_async(
            prompt,
            generation_config=self.model_kwargs,
            safety_settings=self.SAFETY_SETTINGS,
        )
        return response.text

    def generate_prompt(self, context: str, retrieval_question: str) -> str:
        """
        Generates a structured prompt for querying the model from a context and a retrieval question.

        Args:
            context (str): The context or background information relevant to the question.
            retrieval_question (str): The specific question to be answered by the model.

        Returns:
            str: The prompt template with ``{retrieval_question}`` and ``{context}`` filled in.
        """
        return self.prompt_structure.format(
            retrieval_question=retrieval_question,
            context=context)

    def encode_text_to_tokens(self, text: str) -> list[int]:
        """
        Encodes a given text string to a sequence of tokens using the model's tokenizer.

        Args:
            text (str): The text to encode.

        Returns:
            list[int]: A list of token IDs representing the encoded text.
        """
        return self.tokenizer.encode(text)

    def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str:
        """
        Decodes a sequence of tokens back into a text string using the model's tokenizer.

        Args:
            tokens (list[int]): The sequence of token IDs to decode.
            context_length (Optional[int], optional): Number of leading tokens to decode.
                If None, all tokens are decoded.

        Returns:
            str: The decoded text string.
        """
        return self.tokenizer.decode(tokens[:context_length])