genai-on-vertex-ai/gemini/needle_in_a_haystack/needlehaystack/run.py (54 lines of code) (raw):

from dataclasses import dataclass, field
from typing import Optional

from jsonargparse import CLI

from . import LLMNeedleHaystackTester
from .evaluators import Evaluator, GoogleEvaluator
from .providers import ModelProvider, Google


@dataclass
class CommandArgs():
    gcp_project_id: str
    provider: str = "google"
    evaluator: str = "google"
    model_name: str = "gemini-2.0-flash-001"
    evaluator_model_name: Optional[str] = "gemini-2.0-flash-001"
    dynamic_needle: Optional[bool] = True
    needle: Optional[str] = "\nThe best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day.\n"
    haystack_dir: Optional[str] = "PaulGrahamEssays"
    retrieval_question: Optional[str] = "What is the best thing to do in San Francisco?"
    results_version: Optional[int] = 1
    context_lengths_min: Optional[int] = 1000
    context_lengths_max: Optional[int] = 16000
    context_lengths_num_intervals: Optional[int] = 35
    context_lengths: Optional[list[int]] = None
    document_depth_percent_min: Optional[int] = 0
    document_depth_percent_max: Optional[int] = 100
    document_depth_percent_intervals: Optional[int] = 35
    document_depth_percents: Optional[list[int]] = None
    document_depth_percent_interval_type: Optional[str] = "linear"
    num_concurrent_requests: Optional[int] = 1
    save_results: Optional[bool] = True
    save_contexts: Optional[bool] = True
    final_context_length_buffer: Optional[int] = 200
    seconds_to_sleep_between_completions: Optional[float] = None
    print_ongoing_status: Optional[bool] = True


def get_model_to_test(args: CommandArgs) -> ModelProvider:
    """
    Determines and returns the appropriate model provider based on the provided command arguments.

    Args:
        args (CommandArgs): The command line arguments parsed into a CommandArgs dataclass instance.

    Returns:
        ModelProvider: An instance of the specified model provider class.

    Raises:
        ValueError: If the specified provider is not supported.
    """
    match args.provider.lower():
        case "google":
            return Google(model_name=args.model_name, project_id=args.gcp_project_id)
        case _:
            raise ValueError(f"Invalid provider: {args.provider}")


def get_evaluator(args: CommandArgs) -> Evaluator:
    """
    Selects and returns the appropriate evaluator based on the provided command arguments.

    Args:
        args (CommandArgs): The command line arguments parsed into a CommandArgs dataclass instance.

    Returns:
        Evaluator: An instance of the specified evaluator class.

    Raises:
        ValueError: If the specified evaluator is not supported.
    """
    match args.evaluator.lower():
        case "google":
            return GoogleEvaluator(project_id=args.gcp_project_id, model_name=args.evaluator_model_name)
        case _:
            raise ValueError(f"Invalid evaluator: {args.evaluator}")


def main():
    """
    The main function to execute the testing process based on command line arguments.

    It parses the command line arguments, selects the appropriate model provider and evaluator,
    and initiates the testing process either for single-needle or multi-needle scenarios.
    """
    args = CLI(CommandArgs, as_positional=False)
    args.model_to_test = get_model_to_test(args)
    args.evaluator = get_evaluator(args)

    tester = LLMNeedleHaystackTester(**args.__dict__)
    tester.start_test()


if __name__ == "__main__":
    main()
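

# Example invocation (a hedged sketch, not taken from this repo's docs): since this module
# uses relative imports, it is assumed to be run as a package module, and jsonargparse's
# CLI(CommandArgs, as_positional=False) exposes each CommandArgs field as a --flag.
# The project ID below is a placeholder; all other flags simply override the defaults above.
#
#   python -m needlehaystack.run \
#       --gcp_project_id my-gcp-project \
#       --model_name gemini-2.0-flash-001 \
#       --context_lengths_min 1000 \
#       --context_lengths_max 16000 \
#       --document_depth_percent_intervals 10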