def get_model_to_test()

in genai-on-vertex-ai/gemini/needle_in_a_haystack/needlehaystack/run.py [0:0]


def get_model_to_test(args: CommandArgs) -> ModelProvider:
    """
    Build the model provider selected on the command line.

    The provider name is matched case-insensitively, so "Google", "GOOGLE"
    and "google" all select the same provider.

    Args:
        args (CommandArgs): Parsed command line arguments. ``provider``,
            ``model_name`` and ``gcp_project_id`` are read from it.

    Returns:
        ModelProvider: A ``Google`` provider built from ``args.model_name``
        and ``args.gcp_project_id``.

    Raises:
        ValueError: If ``args.provider`` names anything other than "google".
    """
    requested_provider = args.provider.lower()

    # Only one backend is wired up; reject everything else up front.
    if requested_provider != "google":
        # Report the name exactly as the user typed it, not the lowered form.
        raise ValueError(f"Invalid provider: {args.provider}")

    return Google(model_name=args.model_name, project_id=args.gcp_project_id)