# genai-on-vertex-ai/gemini/needle_in_a_haystack/needlehaystack/run.py
from dataclasses import dataclass, field
from typing import Optional
from jsonargparse import CLI
from . import LLMNeedleHaystackTester
from .evaluators import Evaluator, GoogleEvaluator
from .providers import ModelProvider, Google
@dataclass
class CommandArgs:
    """
    Command line options for a needle-in-a-haystack run.

    jsonargparse builds the CLI from these fields, and main() forwards every
    field as a keyword argument to LLMNeedleHaystackTester. Only
    `gcp_project_id` has no default and therefore must be supplied.
    The precise meaning of the tuning fields is defined by
    LLMNeedleHaystackTester, not by this module.
    """

    # Google Cloud project handed to both the provider and the evaluator.
    gcp_project_id: str

    # Names used to select the provider / evaluator implementation and the
    # model each one is constructed with.
    provider: str = "google"
    evaluator: str = "google"
    model_name: str = "gemini-2.0-flash-001"
    evaluator_model_name: Optional[str] = "gemini-2.0-flash-001"

    # Needle, haystack and question settings.
    dynamic_needle: Optional[bool] = True
    needle: Optional[str] = "\nThe best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day.\n"
    haystack_dir: Optional[str] = "PaulGrahamEssays"
    retrieval_question: Optional[str] = "What is the best thing to do in San Francisco?"
    results_version: Optional[int] = 1

    # Context length settings.
    context_lengths_min: Optional[int] = 1000
    context_lengths_max: Optional[int] = 16000
    context_lengths_num_intervals: Optional[int] = 35
    context_lengths: Optional[list[int]] = None

    # Document depth settings.
    document_depth_percent_min: Optional[int] = 0
    document_depth_percent_max: Optional[int] = 100
    document_depth_percent_intervals: Optional[int] = 35
    document_depth_percents: Optional[list[int]] = None
    document_depth_percent_interval_type: Optional[str] = "linear"

    # Execution and output settings.
    num_concurrent_requests: Optional[int] = 1
    save_results: Optional[bool] = True
    save_contexts: Optional[bool] = True
    final_context_length_buffer: Optional[int] = 200
    seconds_to_sleep_between_completions: Optional[float] = None
    print_ongoing_status: Optional[bool] = True
def get_model_to_test(args: CommandArgs) -> ModelProvider:
    """
    Build the model provider named by the command arguments.

    The provider name is compared case-insensitively; "google" is the only
    name this function recognises.

    Args:
        args (CommandArgs): Parsed command line arguments. Reads `provider`,
            `model_name` and `gcp_project_id`.

    Returns:
        ModelProvider: A `Google` provider built from `model_name` and
            `gcp_project_id`.

    Raises:
        ValueError: If `args.provider` names anything other than "google".
    """
    requested = args.provider.lower()

    # Guard clause: reject unknown providers, echoing the name as given.
    if requested != "google":
        raise ValueError(f"Invalid provider: {args.provider}")

    return Google(model_name=args.model_name, project_id=args.gcp_project_id)
def get_evaluator(args: CommandArgs) -> Evaluator:
    """
    Build the evaluator named by the command arguments.

    The evaluator name is compared case-insensitively; "google" is the only
    name this function recognises.

    Args:
        args (CommandArgs): Parsed command line arguments. Reads `evaluator`,
            `gcp_project_id` and `evaluator_model_name`.

    Returns:
        Evaluator: A `GoogleEvaluator` built from `gcp_project_id` and
            `evaluator_model_name`.

    Raises:
        ValueError: If `args.evaluator` names anything other than "google".
    """
    requested = args.evaluator.lower()

    # Guard clause: reject unknown evaluators, echoing the name as given.
    if requested != "google":
        raise ValueError(f"Invalid evaluator: {args.evaluator}")

    return GoogleEvaluator(
        project_id=args.gcp_project_id,
        model_name=args.evaluator_model_name,
    )
def main():
    """
    Run a needle-in-a-haystack test configured from the command line.

    Parses the command line into a CommandArgs instance, builds the model
    provider and the evaluator it names, then passes every parsed option
    together with those two objects to LLMNeedleHaystackTester and starts
    the test.
    """
    args = CLI(CommandArgs, as_positional=False)

    provider = get_model_to_test(args)
    judge = get_evaluator(args)

    # Every parsed field is forwarded as a keyword argument. The "evaluator"
    # entry is replaced by the constructed evaluator instance (it held the
    # evaluator's name until now) and "model_to_test" is added.
    tester_kwargs = {**vars(args), "model_to_test": provider, "evaluator": judge}

    LLMNeedleHaystackTester(**tester_kwargs).start_test()
# Run the CLI only when this module is executed as a script, not on import.
if __name__ == "__main__":
    main()