# genai-on-vertex-ai/gemini/needle_in_a_haystack/needlehaystack/providers/google.py
import os
from typing import Optional

import pkg_resources
import requests
import sentencepiece
import vertexai
from vertexai.generative_models import GenerativeModel, HarmBlockThreshold, HarmCategory
from .model import ModelProvider
class Google(ModelProvider):
"""
    A wrapper class for interacting with Google's Gemini API on Vertex AI, providing methods to encode and
    decode text, generate prompts, and evaluate model responses.
Attributes:
model_name (str): The name of the Google model to use for evaluations and interactions.
model: An instance of the Google Gemini client for API calls.
tokenizer: A tokenizer instance for encoding and decoding text to and from token representations.
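    Example:
        A minimal usage sketch; ``my-gcp-project`` is a placeholder project ID,
        not a value from this module::

            provider = Google(project_id="my-gcp-project")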
"""
DEFAULT_MODEL_KWARGS: dict = dict(max_output_tokens=300,
temperature=0)
VOCAB_FILE_URL = "https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model"
def __init__(self,
project_id: str,
model_name: str = "gemini-2.0-flash-001",
model_kwargs: dict = DEFAULT_MODEL_KWARGS,
vocab_file_url: str = VOCAB_FILE_URL):
"""
Initializes the Google model provider with a specific model.
Args:
            project_id (str): ID of the Google Cloud project to use.
            model_name (str): The name of the Google model to use. Defaults to 'gemini-2.0-flash-001'.
            model_kwargs (dict): Model configuration. Defaults to {max_output_tokens: 300, temperature: 0}.
            vocab_file_url (str): URL of the SentencePiece model file that defines the tokenization vocabulary.
                Defaults to the Gemma tokenizer:
                https://github.com/google/gemma_pytorch/blob/main/tokenizer/tokenizer.model
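
        Example:
            A sketch of overriding the generation defaults; the values shown
            are illustrative, not recommendations::

                provider = Google(
                    project_id="my-gcp-project",
                    model_kwargs=dict(max_output_tokens=500, temperature=0.2),
                )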
"""
self.model_name = model_name
self.model_kwargs = model_kwargs
        # Initialize the Vertex AI SDK; the region is currently fixed to us-central1.
        vertexai.init(project=project_id, location="us-central1")
self.model = GenerativeModel(self.model_name)
        local_vocab_file = 'tokenizer.model'
        if not os.path.exists(local_vocab_file):
            # Download the tokenizer vocabulary file (~4MB), streaming it to
            # disk in chunks rather than iterating byte-by-byte.
            response = requests.get(vocab_file_url, stream=True)
            response.raise_for_status()
            with open(local_vocab_file, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=local_vocab_file)
        # Load the prompt template used to build evaluation prompts; replace
        # gemini_prompt.txt to change the prompt structure.
        resource_path = pkg_resources.resource_filename('needlehaystack', 'providers/gemini_prompt.txt')
        with open(resource_path, 'r') as file:
            self.prompt_structure = file.read()
async def evaluate_model(self, prompt: str) -> str:
"""
Evaluates a given prompt using the Google model and retrieves the model's response.
Args:
prompt (str): The prompt to send to the model.
Returns:
str: The content of the model's response to the prompt.
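
        Example:
            Illustrative only; assumes an initialized provider and a running
            async context::

                answer = await provider.evaluate_model("Where is the needle?")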
"""
response = await self.model.generate_content_async(
prompt,
generation_config=self.model_kwargs,
            # Disable all safety filters so benchmark responses are not blocked.
            safety_settings={
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
}
)
return response.text
def generate_prompt(self, context: str, retrieval_question: str) -> str:
"""
Generates a structured prompt for querying the model, based on a given context and retrieval question.
Args:
context (str): The context or background information relevant to the question.
retrieval_question (str): The specific question to be answered by the model.
Returns:
str: The text prompt
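
        Example:
            Illustrative values::

                prompt = provider.generate_prompt(
                    context="The needle is buried in paragraph twelve.",
                    retrieval_question="Where is the needle?",
                )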
"""
return self.prompt_structure.format(
retrieval_question=retrieval_question,
context=context)
def encode_text_to_tokens(self, text: str) -> list[int]:
"""
Encodes a given text string to a sequence of tokens using the model's tokenizer.
Args:
text (str): The text to encode.
Returns:
list[int]: A list of token IDs representing the encoded text.
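
        Example:
            Illustrative; the exact token IDs depend on the Gemma vocabulary::

                tokens = provider.encode_text_to_tokens("hello world")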
"""
return self.tokenizer.encode(text)
def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str:
"""
Decodes a sequence of tokens back into a text string using the model's tokenizer.
Args:
tokens (list[int]): The sequence of token IDs to decode.
context_length (Optional[int], optional): An optional length specifying the number of tokens to decode. If not provided, decodes all tokens.
Returns:
str: The decoded text string.
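
        Example:
            Truncated decode sketch; assumes ``tokens`` came from
            :meth:`encode_text_to_tokens`::

                text = provider.decode_tokens(tokens, context_length=100)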
"""
return self.tokenizer.decode(tokens[:context_length])
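

# Minimal smoke-test sketch (not part of the provider API). Assumes the
# GOOGLE_CLOUD_PROJECT environment variable names a project with Vertex AI
# enabled; the context and question below are illustrative.
if __name__ == "__main__":
    import asyncio

    provider = Google(project_id=os.environ["GOOGLE_CLOUD_PROJECT"])
    prompt = provider.generate_prompt(
        context="The secret ingredient is fresh basil.",
        retrieval_question="What is the secret ingredient?",
    )
    print(asyncio.run(provider.evaluate_model(prompt)))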