genai-on-vertex-ai/gemini/model_upgrades/text_classification/vertex_script/eval.py (40 lines of code) (raw):

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compare a baseline and a candidate model on a classification dataset.

Both models are evaluated with the same prompt template and dataset using
an exact-match (case-insensitive) accuracy metric, and the mean accuracy of
each run is printed.
"""

import os
from datetime import datetime

import vertexai
from vertexai.evaluation import CustomMetric, EvalTask
from vertexai.generative_models import GenerativeModel


def case_insensitive_match(record: dict[str, str]) -> dict[str, float]:
    """Score one evaluation record by case-insensitive exact match.

    Args:
        record: Mapping that must contain the keys ``"response"`` (the model
            output) and ``"reference"`` (the expected label).

    Returns:
        ``{"accuracy": 1.0}`` when the two values are equal after stripping
        surrounding whitespace and lower-casing, otherwise
        ``{"accuracy": 0.0}``.

    Raises:
        KeyError: If either key is missing from ``record``.
    """
    response = record["response"].strip().lower()
    label = record["reference"].strip().lower()
    return {"accuracy": 1.0 if label == response else 0.0}


def run_eval(
    experiment_name: str,
    baseline_model: str,
    candidate_model: str,
    prompt_template_local_path: str,
    dataset_local_path: str,
) -> None:
    """Evaluate two models on the same task and print their mean accuracy.

    Args:
        experiment_name: Experiment name passed to ``EvalTask``.
        baseline_model: Model name used for the first evaluation run.
        candidate_model: Model name used for the second evaluation run.
        prompt_template_local_path: Path of the text file holding the prompt
            template; its whole content is used as the template.
        dataset_local_path: Dataset location passed to ``EvalTask`` as-is.

    Raises:
        OSError: If the prompt template file cannot be opened or read.
    """
    # One timestamp is shared by both runs so their run names pair up.
    timestamp = datetime.now().strftime("%b-%d-%H-%M-%S").lower()

    # Use a context manager so the file handle is closed deterministically,
    # and pin the encoding so the template reads the same on every platform.
    with open(prompt_template_local_path, encoding="utf-8") as template_file:
        prompt_template = template_file.read()

    task = EvalTask(
        dataset=dataset_local_path,
        metrics=[
            CustomMetric(name="accuracy", metric_function=case_insensitive_match)
        ],
        experiment=experiment_name,
    )

    # Dots in the model name are replaced with dashes when building the
    # experiment run name.
    baseline_results = task.evaluate(
        experiment_run_name=f"{timestamp}-{baseline_model.replace('.', '-')}",
        prompt_template=prompt_template,
        model=GenerativeModel(baseline_model),
    )
    candidate_results = task.evaluate(
        experiment_run_name=f"{timestamp}-{candidate_model.replace('.', '-')}",
        prompt_template=prompt_template,
        model=GenerativeModel(candidate_model),
    )

    print("Baseline model accuracy:", baseline_results.summary_metrics["accuracy/mean"])
    print("Candidate model accuracy:", candidate_results.summary_metrics["accuracy/mean"])


def main() -> None:
    """Initialise Vertex AI from ``PROJECT_ID`` and run the demo evaluation.

    Raises:
        ValueError: If ``PROJECT_ID`` is unset or still the placeholder value.
    """
    # Read the environment once; an unset variable falls back to the
    # placeholder so both cases are rejected by the same check.
    project_id = os.getenv("PROJECT_ID", "your-project-id")
    if project_id == "your-project-id":
        raise ValueError("Please configure your Google Cloud Project ID.")

    vertexai.init(project=project_id, location="us-central1")
    run_eval(
        experiment_name="evals-classifier-demo",
        baseline_model="gemini-1.5-flash-001",
        candidate_model="gemini-2.0-flash-001",
        prompt_template_local_path="prompt_template.txt",
        dataset_local_path="dataset.jsonl",
    )


if __name__ == "__main__":
    main()