# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import string
import hashlib
import uuid
import json
import re
import datetime
import pandas as pd
from utils import config as cfg
from google.cloud import bigquery
from google.cloud import aiplatform
from google.cloud import storage
from vertexai.evaluation import EvalResult
from sqlalchemy.ext.automap import automap_base
from sqlalchemy import create_engine, MetaData, Column, String, Table
from utils.config import PROJECT_ID, LOCATION, STAGING_BUCKET
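# Map of logical table classes to their BigQuery table names and the key
# columns used to match rows during upserts.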
BQ_TABLE_MAP = {
"tasks": {"table_name": cfg.BQ_T_EVAL_TASKS, "keys": ["task_id"]},
"experiments": {"table_name": cfg.BQ_T_EXPERIMENTS, "keys": ["task_id", "experiment_id"]},
"prompts": {"table_name": cfg.BQ_T_PROMPTS, "keys": ["prompt_id"]},
"datasets": {"table_name": cfg.BQ_T_DATASETS, "keys": ["dataset_id"]},
"runs": {"table_name": cfg.BQ_T_EVAL_RUNS, "keys": ["task_id", "experiment_id", "run_id"]},
"run_details": {"table_name": cfg.BQ_T_EVAL_RUN_DETAILS, "keys": ["task_id", "experiment_id", "run_id", "dataset_row_id"]}
}
def get_table_name_keys(table_class):
if table_class not in BQ_TABLE_MAP:
raise ValueError(f"Invalid table class '{table_class}'. Supported {list(BQ_TABLE_MAP.keys())}")
table = BQ_TABLE_MAP[table_class]
return table["table_name"], table["keys"]
def get_db_object(table_class):
table_name, update_keys = get_table_name_keys(table_class)
update_key_cols = [Column(key, String, primary_key=True) for key in update_keys]
return table_name, update_key_cols
def get_db_classes():
# Define engine, metadata and session
engine = create_engine(f'bigquery://{cfg.PROJECT_ID}')
metadata = MetaData()
# Auto populate metadata
for table_class in BQ_TABLE_MAP:
table_name, update_key_cols = get_db_object(table_class)
Table(table_name, metadata, *update_key_cols, autoload_with=engine, schema=cfg.BQ_DATASET_ID)
# create objects
Base = automap_base(metadata=metadata)
Base.prepare()
return Base
def format_dt(dt: datetime.datetime):
return dt.strftime("%m-%d-%Y_%H:%M:%S")
def clean_string(source_string):
    """Replace spaces with underscores, lowercase, and strip unsupported characters."""
    clean_spaces = re.sub(' ', '_', source_string)
    return re.sub(r'[^a-zA-Z0-9 _\n.]', '', clean_spaces.lower())
def write_to_gcs(gcs_path, data):
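    """Upload `data` to the given `gs://` path.

    If `data` is a path to an existing local file it is uploaded as a file,
    otherwise it is uploaded as a string. Returns the link of the created blob.
    """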
if not gcs_path.startswith("gs://"):
raise Exception(f"Invalid Cloud Storage path {gcs_path}. Pass a valid path starting with gs://")
    # Upload as a file if `data` points to an existing local file, otherwise as a string
    upload_as_file = isinstance(data, str) and os.path.isfile(data)
    bucket_name = gcs_path.split("/")[2]
    blob_name = "/".join(gcs_path.split("/")[3:])
    # Initialize the Cloud Storage client
    storage_client = storage.Client()
    # Get the bucket and blob handles
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(blob_name)
    if upload_as_file:
blob.upload_from_filename(data)
else:
blob.upload_from_string(data)
return blob.self_link
def generate_uuid(text: str):
"""Generate a uuid based on text"""
hex_string = hashlib.md5(text.encode('UTF-8')).hexdigest()
random_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=8))
return str(uuid.UUID(hex=hex_string)) + "-" + random_id
class Evals():
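    """Helper for logging and querying evals playbook records stored in BigQuery.

    Illustrative usage (identifiers and fields below are hypothetical):

        evals = Evals()
        task = evals.Task(task_id="task_summarization", task_desc="Summarize news articles")
        evals.log_task(task)
        tasks_df = evals.get_all_tasks()
    """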
def __init__(self):
Base = get_db_classes()
self.Task = Base.classes.eval_tasks
self.Experiment = Base.classes.eval_experiments
self.Prompt = Base.classes.eval_prompts
self.EvalDataset = Base.classes.eval_datasets
self.EvalRunDetail = Base.classes.eval_run_details
self.EvalRun = Base.classes.eval_runs
def log_task(self, task):
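        """Insert or update a task row. `task` can be a `self.Task` instance or a dict keyed by column name."""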
try:
if isinstance(task, self.Task):
task = task.__dict__
if "_sa_instance_state" in task: task.pop("_sa_instance_state")
if not isinstance(task, dict):
raise Exception(f"Invalid task object. Expected: `dict`. Actual: {type(task)}")
self._upsert("tasks", task)
except Exception as e:
print(f"Failed to log task due to following error.")
raise e
def log_prompt(self, prompt):
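        """Insert or update a prompt row. `prompt` can be a `self.Prompt` instance or a dict keyed by column name."""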
try:
if isinstance(prompt, self.Prompt):
prompt = prompt.__dict__
if "_sa_instance_state" in prompt: prompt.pop("_sa_instance_state")
if not isinstance(prompt, dict):
raise Exception(f"Invalid task object. Expected: `dict`. Actual: {type(prompt)}")
self._upsert("prompts", prompt)
except Exception as e:
print(f"Failed to log prompt due to following error.")
raise e
def _get_all(self, table_class, limit_offset=20, as_dict=False):
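        """Return the most recent `limit_offset` rows of a table as a DataFrame (or list of dicts if `as_dict`)."""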
client = bigquery.Client(project=cfg.PROJECT_ID)
table_name = BQ_TABLE_MAP.get(table_class).get("table_name")
table_id = f"{cfg.PROJECT_ID}.{cfg.BQ_DATASET_ID}.{table_name}"
table = client.get_table(table_id)
cols = [schema.name for schema in table.schema]
sql = f"""
SELECT {", ".join(cols)}
FROM `{table_id}`
ORDER BY create_datetime DESC
LIMIT {limit_offset}
"""
df = client.query_and_wait(sql).to_dataframe()
if as_dict:
return df.to_dict(orient='records')
else:
return df
def get_all_tasks(self, limit_offset=20, as_dict=False):
return self._get_all("tasks", limit_offset, as_dict)
def get_all_experiments(self, limit_offset=20, as_dict=False):
return self._get_all("experiments", limit_offset, as_dict)
def get_all_prompts(self, limit_offset=20, as_dict=False):
return self._get_all("prompts", limit_offset, as_dict)
def get_all_eval_runs(self, limit_offset=20, as_dict=False):
return self._get_all("runs", limit_offset, as_dict)
def get_all_eval_run_details(self, limit_offset=20, as_dict=False):
return self._get_all("run_details", limit_offset, as_dict)
def _get_one(self, table_class, where_keys, limit_offset=1, as_dict=False):
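        """Return rows matching `where_keys` (most recent first) as a DataFrame, or a list of dicts if `as_dict`."""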
client = bigquery.Client(project=cfg.PROJECT_ID)
table_name = BQ_TABLE_MAP.get(table_class).get("table_name")
table_id = f"{cfg.PROJECT_ID}.{cfg.BQ_DATASET_ID}.{table_name}"
table = client.get_table(table_id)
cols = [schema.name for schema in table.schema]
        where_clause = ""
        if where_keys:
            where_clause = "WHERE "
            where_clause += " AND ".join([f"{k} = '{v}'" for k, v in where_keys.items()])
sql = f"""
SELECT {", ".join(cols)}
FROM `{table_id}`
{where_clause}
ORDER BY create_datetime DESC
LIMIT {limit_offset}
"""
df = client.query_and_wait(sql).to_dataframe()
if as_dict:
return df.to_dict(orient='records')
else:
return df
def get_experiment(self, experiment_id, task_id: str="", as_dict=False):
where_keys = {}
if experiment_id:
where_keys["experiment_id"] = experiment_id
if task_id:
where_keys["task_id"] = task_id
return self._get_one("experiments", where_keys, as_dict=as_dict)
else:
raise Exception(f"Experiment ID is required.")
def get_prompt(self, prompt_id, as_dict=False):
where_keys = {}
if prompt_id:
where_keys["prompt_id"] = prompt_id
return self._get_one("prompts", where_keys, as_dict=as_dict)
else:
raise Exception(f"Prompt ID is required.")
def get_eval_runs(self, experiment_id, experiment_run_id: str="", task_id: str="", as_dict=False):
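        """Return run summary metrics joined with the experiment configuration, transposed for side-by-side viewing."""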
where_keys = {}
if not experiment_run_id:
print("[INFO] experiment_run_id not passed. Showing last 5 runs (if available).")
limit_offset = 5
else:
where_keys["run_id"] = experiment_run_id
limit_offset = 1
if experiment_id:
where_keys["experiment_id"] = experiment_id
if task_id:
where_keys["task_id"] = task_id
# get experiment
exp_df = self.get_experiment(experiment_id=experiment_id)
exp_df = exp_df[["experiment_id", "experiment_desc", "prompt_id", "model_endpoint", "model_name", "generation_config"]]
model_config_df_exp = pd.json_normalize(exp_df['generation_config'])
exp_df = pd.concat([exp_df.drop('generation_config', axis=1), model_config_df_exp], axis=1)
# get metrics
metrics_df = self._get_one("runs", where_keys, limit_offset=limit_offset, as_dict=False)
metrics_df = metrics_df[['experiment_id', 'run_id', 'metrics', 'task_id', 'create_datetime', 'update_datetime', 'tags']]
metrics_df = pd.merge(exp_df, metrics_df, on='experiment_id', how='left')
            # metrics are stored as JSON strings (see log_eval_run), so parse them with json.loads
            metrics_df['metrics'] = metrics_df['metrics'].apply(json.loads)
metrics_df_exp = pd.json_normalize(metrics_df['metrics'])
metrics_df = pd.concat([metrics_df.drop('metrics', axis=1), metrics_df_exp], axis=1)
if as_dict:
return metrics_df.T.to_dict(orient='records')
else:
return metrics_df.T
else:
raise Exception(f"experiment_id is required.")
def compare_eval_runs(self, experiment_run_ids, as_dict=False):
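        """Join runs with their experiment and prompt records and return a transposed
        DataFrame (one column per run) for side-by-side comparison of metrics and configs."""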
if not experiment_run_ids:
raise Exception(f"experiment_run_ids are required to compare runs")
if isinstance(experiment_run_ids, str):
experiment_run_ids = [experiment_run_ids]
if isinstance(experiment_run_ids, list):
experiment_run_ids = ", ".join([f"'{run}'" for run in experiment_run_ids])
table_prefix = f"{cfg.PROJECT_ID}.{cfg.BQ_DATASET_ID}"
client = bigquery.Client(project=cfg.PROJECT_ID)
sql = f"""
SELECT
runs.task_id,
runs.run_id,
runs.experiment_id,
exp.experiment_desc,
exp.model_endpoint,
exp.model_name,
exp.generation_config,
prompt.prompt_template,
prompt.system_instruction,
runs.metrics,
runs.create_datetime
FROM
`{table_prefix}.{BQ_TABLE_MAP.get('runs').get('table_name')}` runs
JOIN
`{table_prefix}.{BQ_TABLE_MAP.get('experiments').get('table_name')}` exp
ON
runs.experiment_id = exp.experiment_id
LEFT JOIN
`{table_prefix}.{BQ_TABLE_MAP.get('prompts').get('table_name')}` prompt
ON
exp.prompt_id = prompt.prompt_id
WHERE runs.run_id IN ({experiment_run_ids})
ORDER BY runs.create_datetime DESC
"""
df = client.query_and_wait(sql).to_dataframe()
        # format metrics and generation config (both stored as JSON strings)
        df['metrics'] = df['metrics'].apply(json.loads)
        df['generation_config'] = df['generation_config'].apply(json.loads)
df_metrics_exp = pd.json_normalize(df['metrics'])
df_config_exp = pd.json_normalize(df['generation_config'])
df = pd.concat([df.drop(['metrics', 'generation_config'], axis=1), df_metrics_exp, df_config_exp], axis=1)
if as_dict:
return df.T.to_dict(orient='records')
else:
return df.T
def grid_search(self, task_id, experiment_run_ids, opt_metrics, opt_params):
"""
Performs grid search on the evaluation results and returns the best parameter combinations for each metric.
Args:
task_id: The specific task ID to filter the results.
experiment_run_ids: List of experiment run IDs to include in the grid search.
opt_metrics: List of metrics to optimize (e.g., ["ROUGE_1", "BLEU"]).
opt_params: List of parameters to consider in the grid search (e.g., ["prompt_template", "temperature"]).
Returns:
A dictionary where keys are the optimization metrics and values are the corresponding best parameter combinations.
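        Example (illustrative; run IDs and values are hypothetical):
            best = evals.grid_search(
                task_id="task_summarization",
                experiment_run_ids=["run-1", "run-2"],
                opt_metrics=["ROUGE_1"],
                opt_params=["temperature"],
            )
            # best["ROUGE_1"] -> {"params": {"temperature": 0.2}, "metric_mean": ..., "metric_std": ...}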
"""
        # Get the comparison table for the runs (transpose back to one row per run)
grid_df = (self.compare_eval_runs(experiment_run_ids)).T
# Filter the grid_df based on the task_id
filtered_df = grid_df[grid_df['task_id'] == task_id]
# Initialize a dictionary to store the best parameter combinations for each metric
best_params = {}
for metric in opt_metrics:
            # Column names in grid_df follow the "<metric>/mean" and "<metric>/std" convention
            metric_mean_col = metric.lower() + "/mean"
            metric_std_col = metric.lower() + "/std"
# Find the row with the highest value for the given metric
best_row = filtered_df.loc[filtered_df[metric_mean_col].idxmax()]
# Extract the values of the optimization parameters and the std from the best row
best_params[metric] = {
"params": {param: best_row[param] for param in opt_params},
"metric_mean": best_row[metric_mean_col],
"metric_std": best_row[metric_std_col]
}
return best_params
def get_eval_run_detail(self, experiment_run_id, task_id: str="", limit_offset=100, as_dict=False):
where_keys = {}
if not experiment_run_id:
raise Exception(f"experiment_run_id is required is to get run detail.")
where_keys["run_id"] = experiment_run_id
if task_id:
where_keys["task_id"] = task_id
details_df = self._get_one("run_details", where_keys, limit_offset=limit_offset, as_dict=False)
if as_dict:
return details_df.T.to_dict(orient='records')
else:
# print(f"[INFO] Showing top {limit_offset} rows. For viewing more # of rows, pass `limit_offset`.")
return details_df
def _upsert(self, table_class, rows, debug=False):
"""Inserts or updates rows in the specified BigQuery table.
Args:
table_name: The name of the table.
rows: A list of dictionaries where each dictionary represents a row.
update_keys: A list of keys to use for updating existing rows.
"""
table_name, update_keys = get_table_name_keys(table_class)
if isinstance(rows, dict):
rows = [rows]
# Validate that update keys are present in all rows
all_keys = set().union(*(d.keys() for d in rows))
for row in rows:
for key in update_keys:
if key not in row:
raise ValueError(f"Update key '{key}' not found in row: {row}")
# Get BigQuery table schema
table_id = f"{cfg.PROJECT_ID}.{cfg.BQ_DATASET_ID}.{table_name}"
client = bigquery.Client(project=cfg.PROJECT_ID)
table = client.get_table(table_id)
schema = {schema.name:schema.field_type for schema in table.schema}
# Construct the MERGE query dynamically
merge_query = f"""
MERGE INTO `{table_id}` AS target
USING (
SELECT * FROM UNNEST(@rows)
) AS source
ON {" AND ".join(f"target.{key} = source.{key}" for key in update_keys)}
"""
if update_keys:
merge_query += f""" WHEN MATCHED THEN
UPDATE SET {", ".join(f"target.{key} = source.{key}" for key in all_keys if key not in update_keys + ['create_datetime'])}
"""
merge_query += f""" WHEN NOT MATCHED THEN
INSERT({", ".join([key for key in all_keys])})
VALUES({", ".join(f"source.{key}" for key in all_keys)})
"""
        # Convert rows to query parameters: each row becomes a STRUCT so values are
        # bound as parameters instead of being interpolated into the SQL string.
        rows_for_query = []
        for row in rows:
            row_for_query = []
            for key, val in row.items():
                field_type = schema.get(key)
                if field_type == "BOOLEAN":
                    field_type = "BOOL"
                if isinstance(val, datetime.datetime):
                    val = val.isoformat()
                if isinstance(val, list):
                    row_for_query.append(bigquery.ArrayQueryParameter(key, field_type, val))
                else:
                    # None values are passed through as NULLs so every STRUCT has the same fields
                    row_for_query.append(bigquery.ScalarQueryParameter(key, field_type, val))
            rows_for_query.append(bigquery.StructQueryParameter("x", *row_for_query))
job_config = bigquery.QueryJobConfig(
query_parameters=[bigquery.ArrayQueryParameter("rows", "STRUCT", rows_for_query)]
)
        if debug:
            print("MERGE Query:")
            print(merge_query)
            print("\nRows:")
            print(rows_for_query)
query_job = client.query(merge_query, job_config=job_config)
query_job.result() # Wait for the MERGE to complete
def log_experiment(self,
task_id,
experiment_id,
prompt,
model,
metric_config,
experiment_desc="",
is_streaming=False,
tags=[],
metadata={}):
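        """Create and persist an experiment row from a prompt object and a model (e.g. a Vertex AI `GenerativeModel`).

        Model name, endpoint, generation config and safety settings are read from the
        model object, and the resulting row is upserted into the experiments table.
        """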
# create experiment object
experiment = self.Experiment(
experiment_id=experiment_id,
experiment_desc=experiment_desc,
task_id=task_id,
prompt_id = prompt.prompt_id,
elapsed_time = 0
)
# add model information
experiment.model_name = model._model_name.split("/")[-1]
experiment.model_endpoint = aiplatform.constants.base.API_BASE_PATH
experiment.is_streaming = is_streaming
# add generation config
if model._generation_config and isinstance(model._generation_config, dict):
experiment.generation_config = json.dumps(model._generation_config)
# add safety settings
if model._safety_settings:
if isinstance(model._safety_settings, dict):
safety_settings_as_dict = {
category.name: threshold.name
for category, threshold in model._safety_settings.items()
}
elif isinstance(model._safety_settings, list):
safety_settings_as_dict = {
s.to_dict().get("category", "HARM_CATEGORY_UNSPECIFIED"):s.to_dict().get("threshold")
for s in model._safety_settings
}
else:
safety_settings_as_dict = {}
experiment.safety_settings = json.dumps(safety_settings_as_dict)
# add metric config
experiment.metric_config = str(metric_config)
# additional fields
experiment.create_datetime = datetime.datetime.now()
experiment.update_datetime = datetime.datetime.now()
experiment.tags = tags
if isinstance(metadata, dict):
experiment.metadata = json.dumps(metadata)
try:
if isinstance(experiment, self.Experiment):
experiment = experiment.__dict__
if "_sa_instance_state" in experiment: experiment.pop("_sa_instance_state")
if not isinstance(experiment, dict):
raise Exception(f"Invalid task object. Expected: `dict`. Actual: {type(experiment)}")
self._upsert("experiments", experiment)
except Exception as e:
print(f"Failed to log experiment due to following error.")
raise e
return experiment
def save_prompt_template(self, task_id, experiment_id, prompt_id, prompt_template):
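        """Save a prompt template to the staging bucket under `<task_id>/prompts/<experiment_id>/`."""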
# Construct the full file path in the bucket
fmt_prompt_id = clean_string(prompt_id)
prefix = f'{task_id}/prompts/{experiment_id}'
gcs_file_path = f'gs://{STAGING_BUCKET}/{prefix}/template_{fmt_prompt_id}.txt'
# write to GCS
write_to_gcs(gcs_file_path, prompt_template)
print(f"Prompt template saved to {gcs_file_path} successfully!")
def save_prompt(self, text, run_path, blob_name):
"""
Saves the given text to a Google Cloud Storage bucket and returns the blob's URI.
Args:
text: The text to be saved.
            run_path: Path prefix (within the staging bucket) under which the blob is stored.
blob_name: The desired name for the blob (file) in the bucket.
Returns:
The URI of the created blob.
"""
# Construct the full file path in the bucket
gcs_file_path = f'gs://{STAGING_BUCKET}/{run_path}/{blob_name}.txt'
blob_link = write_to_gcs(gcs_file_path, text)
return blob_link
def log_eval_run(self,
experiment_run_id: str,
experiment,
eval_result,
run_path,
tags=[],
metadata={}):
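        """Persist per-row details and summary metrics of an evaluation run.

        Per-row prompts are saved to Cloud Storage under `run_path`; metric values are
        stored as JSON strings in the run_details and runs tables.
        """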
# log run details
if not isinstance(eval_result, EvalResult):
raise Exception(f"Invalid eval_result object. Expected: `vertexai.evaluation.EvalResult` Actual: {type(eval_result)}")
if isinstance(experiment, dict):
experiment = self.Experiment(**experiment)
if not isinstance(experiment, self.Experiment):
raise Exception(f"Invalid experiment object. Expected: `Experiment` Actual: {type(experiment)}")
# get run details from the Rapid Eval evaluation task
detail_df = eval_result.metrics_table.to_dict(orient="records")
summary_dict = eval_result.summary_metrics
        # Columns in the metrics table that are inputs/outputs rather than metric values
        non_metric_keys = ['context', 'reference', 'instruction', 'dataset_row_id', 'completed_prompt', 'prompt', 'response']
# prepare run details
run_details = []
for row in detail_df:
row.get("prompt")
metrics = {k: row[k] for k in row if k not in non_metric_keys}
run_detail = dict(
run_id=experiment_run_id,
experiment_id=experiment.experiment_id,
task_id=experiment.task_id,
dataset_row_id=row.get("dataset_row_id"),
system_instruction=row.get("instruction"),
input_prompt_gcs_uri=self.save_prompt(row.get("prompt"), run_path, row.get("dataset_row_id")),
output_text=row.get("response"),
ground_truth=row.get("reference"),
metrics=json.dumps(metrics),
# additional fields
latencies=[],
create_datetime=datetime.datetime.now(),
update_datetime=datetime.datetime.now(),
tags=tags,
metadata=json.dumps(metadata) if isinstance(metadata, dict) else None
)
run_details.append(run_detail)
try:
self._upsert("run_details", run_details)
except Exception as e:
print(f"Failed to log run details due to following error.")
raise e
# prepare run summary metrics
run_summary = dict(
run_id=experiment_run_id,
experiment_id=experiment.experiment_id,
task_id=experiment.task_id,
# dataset_row_id = experiment.dataset_row_id,
metrics=json.dumps(summary_dict),
# additional fields
create_datetime=datetime.datetime.now(),
update_datetime=datetime.datetime.now(),
tags=tags,
metadata=json.dumps(metadata) if isinstance(metadata, dict) else None
)
try:
self._upsert("runs", run_summary)
except Exception as e:
print(f"Failed to log run summary due to following error.")
raise e