gemini/prompts/prompt_optimizer/vapo_lib.py

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=too-many-lines
"""Utility functions and classes for the VAPO notebook."""

import csv
import io
import json
import random
import re
import string
import subprocess
from typing import Any, Callable, Union

from IPython.core.display import DisplayHandle
from IPython.display import HTML, display
from google.cloud import aiplatform, storage
import ipywidgets as widgets
import jinja2
import jinja2.meta
from jsonschema import ValidationError, validate
import pandas as pd
import plotly.graph_objects as go
from tenacity import retry, wait_random_exponential
from tensorflow.io import gfile
from vertexai import generative_models
from vertexai.evaluation import EvalTask
from vertexai.generative_models import (
    Content,
    GenerationConfig,
    GenerativeModel,
    Part,
    SafetySetting,
    Tool,
    ToolConfig,
)


def is_target_required_metric(eval_metric: str) -> bool:
    """Check if the metric requires the target label."""
    return eval_metric in [
        "bleu",
        "exact_match",
        "question_answering_correctness",
        "rouge_1",
        "rouge_2",
        "rouge_l",
        "rouge_l_sum",
        "tool_call_valid",
        "tool_name_match",
        "tool_parameter_key_match",
        "tool_parameter_kv_match",
    ]


def is_run_target_required(eval_metric_types: list[str], source_model: str) -> bool:
    """Check if the run requires the target label."""
    if source_model:
        return False
    label_required = False
    for metric in eval_metric_types:
        label_required = label_required or is_target_required_metric(metric)
    return label_required


_TARGET_KEY = "target"


def load_file_from_gcs(dataset: str) -> str:
    """Loads the file from GCS and returns it as a string."""
    if dataset.startswith("gs://"):
        with gfile.GFile(dataset, "r") as f:
            return f.read()
    else:
        raise ValueError(
            "Unsupported file location. Only GCS paths starting with 'gs://' are"
            " supported."
        )


def parse_jsonl(data_str: str) -> list[dict[str, str]]:
    """Parses the content of a JSONL file and returns a list of dictionaries."""
    data = []
    lines = data_str.splitlines()
    for line in lines:
        if line:
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"Error decoding JSON on line: {line}. Error: {e}"
                ) from e
    return data


def parse_and_validate_csv(data_str: str) -> list[dict[str, str]]:
    """Parses and validates the content of a CSV file and returns a list of dictionaries."""
    data = []
    csv_reader = csv.reader(io.StringIO(data_str))

    # Extract and validate headers
    try:
        headers = next(csv_reader)
        if not headers:
            raise ValueError("The CSV file has an empty or invalid header row.")
    except StopIteration as e:
        raise ValueError("The CSV file is empty.") from e

    # Validate and process rows
    for row_number, row in enumerate(csv_reader, start=2):
        if len(row) != len(headers):
            raise ValueError(
                f"Row {row_number} has an inconsistent number of fields. "
                f"Expected {len(headers)} fields but found {len(row)}."
            )
        # Create dictionary for each row using headers as keys
        item = dict(zip(headers, row))
        data.append(item)
    return data


def load_dataset(dataset: str) -> list[dict[str, str]]:
    """Loads and parses the dataset based on its file type ('.jsonl' or '.csv')."""
    # Load the file from GCS
    data_str = load_file_from_gcs(dataset)

    # Parse based on file type
    if dataset.endswith(".jsonl"):
        return parse_jsonl(data_str)
    if dataset.endswith(".csv"):
        return parse_and_validate_csv(data_str)
    raise ValueError(
        "Unsupported file type. Please provide a file with '.jsonl' or '.csv'"
        " extension."
    )


def validate_prompt_and_data(
    template: str,
    dataset_path: str,
    placeholder_to_content: str,
    label_enforced: bool,
) -> None:
    """Validates the prompt template and the dataset."""
    data = load_dataset(dataset_path)
    placeholder_to_content_json = json.loads(placeholder_to_content)
    template = re.sub(r"(?<!{){(?!{)(\s*\w+\s*)(?<!})}(?!})", r"{{\1}}", template)
    env = jinja2.Environment()
    try:
        parsed_content = env.parse(template)
    except jinja2.exceptions.TemplateSyntaxError as e:
        raise ValueError(f"Invalid template: {template}") from e
    template_variables = jinja2.meta.find_undeclared_variables(parsed_content)
    extra_keys = set()
    for ex in data:
        ex.update(placeholder_to_content_json)
        missing_keys = [key for key in template_variables if key not in ex]
        extra_keys.update([key for key in ex if key not in template_variables])
        if label_enforced:
            if _TARGET_KEY not in ex:
                raise ValueError(
                    f"The example {ex} doesn't have a key corresponding to the target"
                    f" var: {_TARGET_KEY}"
                )
            if not ex[_TARGET_KEY]:
                raise ValueError(f"The following example has an empty target: {ex}")
        if missing_keys:
            raise ValueError(
                f"The example {ex} doesn't have a key corresponding to the following"
                f" template vars: {missing_keys}"
            )
    if extra_keys:
        raise Warning(
            "Warning: extra keys in the examples not used in the prompt"
            f" template: {extra_keys}"
        )


def run_custom_job(
    display_name: str,
    container_uri: str,
    container_args: dict[str, str],
) -> aiplatform.CustomJob:
    """A sample to create custom jobs."""
    worker_pool_specs = [
        {
            "replica_count": 1,
            "container_spec": {
                "image_uri": container_uri,
                "args": [f"--{k}={v}" for k, v in container_args.items()],
            },
            "machine_spec": {
                "machine_type": "n1-standard-4",
            },
        }
    ]

    custom_job = aiplatform.CustomJob(
        display_name=display_name,
        worker_pool_specs=worker_pool_specs,
    )
    custom_job.submit()
    return custom_job


def run_apd(
    config: dict[str, str], bucket_uri: str, display_name: str
) -> aiplatform.CustomJob:
    """A function to run the Vertex prompt optimizer."""
    print(f"\n\nJob display name: {display_name}")
    version = "preview_v1_0"
    container_uri = "us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/apd"

    config_path = f"{bucket_uri}/{display_name}/input_config.json"
    with gfile.GFile(config_path, "w") as f:
        json.dump(config, f)

    aiplatform.init(
        project=config["project"],
        location=config["target_model_location"],
        staging_bucket=f"{bucket_uri}/{display_name}",
    )
    return run_custom_job(
        display_name=display_name,
        container_uri=f"{container_uri}:{version}",
        container_args={"config": config_path},
    )


def update_best_display(
    df: pd.DataFrame,
    textarea: widgets.Textarea,
    best_score_label: widgets.Label,
    eval_metric: str,
) -> None:
    """Update the best prompt display."""
    df["score"] = df[f"metrics.{eval_metric}/mean"]
    best_template = df.loc[df["score"].argmax(), "prompt"]
    best_score = df.loc[df["score"].argmax(), "score"]
    original_score = df.loc[0, "score"]

    def placeholder_llm() -> str:
        return "{{llm()}}"

    env = jinja2.Environment(loader=jinja2.BaseLoader())
    env.globals["llm"] = placeholder_llm
    best_template = best_template.replace("store('answer', llm())", "llm()")
    textarea.value = best_template

    improvement = best_score - original_score
    no_improvement_str = "\nNo better template is found yet." if not improvement else ""
    best_score_label.value = (
        f"Score: {best_score}"
        f" Improvement: {improvement: .3f} {no_improvement_str}"
    )


def generate_dataframe(filename: str) -> pd.DataFrame:
    """Generates a pandas dataframe from a json file."""
    if not gfile.exists(filename):
        return pd.DataFrame()
    with gfile.GFile(filename, "r") as f:
        try:
            data = json.load(f)
        except json.JSONDecodeError:
            return pd.DataFrame()
    return pd.json_normalize(data)


def left_aligned_df_html(df: pd.DataFrame) -> HTML:
    """Displays a Pandas DataFrame in Colab with left-aligned values."""
    # Convert to HTML table, but keep the HTML in a variable
    html_table = df.to_html(index=False, classes="left-aligned")

    # Add CSS styling to left-align table data cells and override default styles
    styled_html = f"""
    <style>
        .left-aligned td, .left-aligned th {{
            text-align: left !important;
        }}
    </style>
    {html_table}
    """

    # Display the styled HTML table
    return HTML(styled_html)


def extract_top_level_function_name(source_code: str) -> str | None:
    """Extract the top level function name from the source code."""
    match = re.search(r"^def\s+([a-zA-Z_]\w*)\s*\(", source_code, re.MULTILINE)
    if match:
        return match.group(1)
    return None


class ProgressForm:
    """A class to display the progress of the optimization job."""

    # pylint: disable=too-many-instance-attributes
    def __init__(self, params: dict[str, str]) -> None:
        """Initialize the progress form."""
        self.instruction_progress_bar = None
        self.instruction_display = None
        self.instruction_best = None
        self.instruction_score = None
        self.demo_progress_bar = None
        self.demo_display = None
        self.demo_best = None
        self.demo_score = None

        self.job_state_display = display(
            HTML("<span>Job State: Not Started!</span>"), display_id=True
        )
        self.status_display = display(HTML(""), display_id=True)

        if params["optimization_mode"] in ["instruction", "instruction_and_demo"]:
            (
                self.instruction_progress_bar,
                self.instruction_display,
                self.instruction_best,
                self.instruction_score,
            ) = self.create_progress_ui("Instruction", params["num_steps"])
        if params["optimization_mode"] in ["demonstration", "instruction_and_demo"]:
            (
                self.demo_progress_bar,
                self.demo_display,
                self.demo_best,
                self.demo_score,
            ) = self.create_progress_ui(
                "Demonstration", params["num_demo_set_candidates"]
            )

        if len(params["eval_metrics_types"]) == 1:
            self.eval_metric = params["eval_metrics_types"][0]
        else:
            self.eval_metric = "composite_metric"

        self.output_path = params["output_path"]
        self.instruction_df = None
        self.demo_df = None

    # pylint: disable=too-many-positional-arguments,too-many-arguments
    def update_progress(
        self,
        progress_bar: widgets.IntProgress | None,
        templates_file: str,
        df: pd.DataFrame | None,
        df_display: DisplayHandle,
        best_textarea: widgets.Textarea,
        best_score: widgets.Label,
        eval_metric: str,
    ) -> pd.DataFrame:
        """Update the progress of the optimization job."""

        def get_last_step(df: pd.DataFrame) -> int:
            if df.empty:
                return -1
            return int(df["step"].max())

        if progress_bar is None or df is None:
            return pd.DataFrame()

        new_df = generate_dataframe(templates_file)
        last_step = get_last_step(df)
        new_last_step = get_last_step(new_df)
        if new_last_step > last_step:
            df_display.update(left_aligned_df_html(new_df))
            update_best_display(new_df, best_textarea, best_score, eval_metric)
            progress_bar.value = progress_bar.value + new_last_step - last_step
        return new_df

    def create_progress_ui(
        self, opt_mode: str, num_opt_steps: str
    ) -> tuple[widgets.IntProgress, DisplayHandle, widgets.Textarea, widgets.Label]:
        """Create the progress UI for a specific optimization mode."""
        print(f"\n\n{opt_mode} Optimization")
        progress_bar = widgets.IntProgress(
            value=0, min=0, max=int(num_opt_steps), step=1, description="Progress"
        )
        display(progress_bar)
        print("\nGenerated Templates:")
        templates_display = display("No template is evaluated yet!", display_id=True)
        print("\nBest Template so far:")
        best_textarea = widgets.Textarea(
            value="NA",
            disabled=False,
            layout=widgets.Layout(width="80%", height="150px"),
        )
        display(best_textarea)
        best_score = widgets.Label(value="Score: NA Improvement: NA")
        display(best_score)
        return progress_bar, templates_display, best_textarea, best_score

    def monitor_progress(self, job: aiplatform.CustomJob) -> bool:
        """Monitor the progress of the optimization job."""
        self.job_state_display.update(
            HTML(f"<span>Job State: {job.state.name}</span>")
        )

        # Initial display of the templates.
        instruction_templates_file = f"{self.output_path}/instruction/templates.json"
        demo_templates_file = f"{self.output_path}/demonstration/templates.json"

        if not job.done():
            self.instruction_df = self.update_progress(
                self.instruction_progress_bar,
                instruction_templates_file,
                self.instruction_df,
                self.instruction_display,
                self.instruction_best,
                self.instruction_score,
                self.eval_metric,
            )
            self.demo_df = self.update_progress(
                self.demo_progress_bar,
                demo_templates_file,
                self.demo_df,
                self.demo_display,
                self.demo_best,
                self.demo_score,
                self.eval_metric,
            )
            return True

        if job.state.name != "JOB_STATE_SUCCEEDED":
            errors = [f"Error: Job failed with error {job.error}."]
            for err_file in [
                f"{self.output_path}/instruction/error.json",
                f"{self.output_path}/demonstration/error.json",
                f"{self.output_path}/error.json",
            ]:
                if gfile.exists(err_file):
                    with gfile.GFile(err_file, "r") as f:
                        error_json = json.load(f)
                        errors.append(f"Detailed error: {error_json['Error']}")
                    errors.append(
                        f"Please feel free to send {err_file} to the VAPO team to help"
                        " resolve the issue."
                    )
            errors.append(
                "All the templates found before failure can be found under"
                f" {self.output_path}"
            )
            errors.append(
                "Please consider rerunning to check whether the failure is transient."
            )
            err = "\n".join(errors)
            err = err.replace("\n", "<br>")
            self.status_display.update(HTML(f'<span style="color: red;">{err}</span>'))
        else:
            self.status_display.update(
                HTML(
                    '<span style="color: green;">Job succeeded!</span> <span>All the'
                    f" artifacts can be found under {self.output_path}</span>"
                )
            )
        return False


def display_dataframe(df: pd.DataFrame) -> None:
    """Display a pandas dataframe in Colab."""

    # Function to wrap text in a scrollable div
    def wrap_in_scrollable_div(text: str) -> str:
        return f'<div class="scrollable">{text}</div>'

    # Apply the function to every cell using the format method
    styled_html = df.style.format(wrap_in_scrollable_div).to_html(index=False)

    # Display the HTML in the notebook
    display(HTML(styled_html))


def split_gcs_path(gcs_path: str) -> tuple[str, str]:
    """Splits a full GCS path into bucket name and prefix."""
    if gcs_path.startswith("gs://"):
        path_without_scheme = gcs_path[5:]  # Remove the 'gs://' part
        parts = path_without_scheme.split("/", 1)
        bucket_name = parts[0]
        prefix = parts[1] if len(parts) > 1 else ""
        return bucket_name, prefix
    raise ValueError("Invalid GCS path. Must start with 'gs://'")


def list_gcs_objects(full_gcs_path: str) -> list[str]:
    """Lists all the objects in the given GCS path."""
    bucket_name, prefix = split_gcs_path(full_gcs_path)
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blobs = bucket.list_blobs(
        prefix=prefix
    )  # List all objects that start with the prefix
    return [blob.name for blob in blobs]


def find_directories_with_files(
    full_gcs_path: str, required_files: list[str]
) -> list[str]:
    """Finds directories containing specific files under the given full GCS path."""
    bucket_name, prefix = split_gcs_path(full_gcs_path)
    all_paths = list_gcs_objects(f"gs://{bucket_name}/{prefix}")
    directories = set()

    # Create a dictionary to track files found in each directory
    file_presence: dict[str, set[str]] = {}
    for path in all_paths:
        # Get the directory part of the path
        directory = "/".join(path.split("/")[:-1])
        filename = path.split("/")[-1]  # Get the filename part of the path
        if directory:
            if directory not in file_presence:
                file_presence[directory] = set()
            file_presence[directory].add(filename)

    # Check which directories have all required files
    for directory, files in file_presence.items():
        if all(file in files for file in required_files):
            directories.add(f"gs://{bucket_name}/{directory}")
    return list(directories)


def extract_metric_name(metric_string: str) -> str:
    """Extract the metric name from a string."""
    # Use a regular expression to find the metric name
    match = re.search(r"\.(\w+)/", metric_string)
    # Return the matched group if found
    return match.group(1) if match else metric_string


def read_file_from_gcs(filename: str) -> str:
    """Read a file from GCS."""
    with gfile.GFile(filename, "r") as f:
        return f.read()


def process_results(df: pd.DataFrame) -> pd.DataFrame:
    """Process the results removing columns that could be confusing."""
    columns_to_drop = []
    # Dropping columns that could be confusing.
    for col in df.columns:
        if "confidence" in col:
            columns_to_drop.append(col)
        if "raw_eval_resp" in col:
            columns_to_drop.append(col)
        if col == "instruction":
            columns_to_drop.append(col)
        if col == "context":
            columns_to_drop.append(col)
    return df.drop(columns=columns_to_drop)


class ResultsUI:
    """A UI to display the results of a VAPO run."""

    def __init__(self, path: str) -> None:
        """Initialize the UI."""
        required_files = ["eval_results.json", "templates.json"]
        runs = find_directories_with_files(path, required_files)

        self.run_label = widgets.Label("Select Run:")
        self.run_dropdown = widgets.Dropdown(
            options=runs, value=runs[0], layout=widgets.Layout(width="200px")
        )
        self.run_dropdown.observe(self.display_run_handler, names="value")

        # Create a label widget for the description
        self.dropdown_description = widgets.Label("Select Template:")
        self.template_dropdown = widgets.Dropdown(
            options=[],
            value=None,
            layout=widgets.Layout(width="400px"),
            disabled=True,
        )
        self.template_dropdown.observe(self.display_template_handler, names="value")

        self.results_output = widgets.Output(
            layout=widgets.Layout(
                height="600px", overflow="auto", margin="20px 0px 0px 0px"
            )
        )
        self.display_run(runs[0])

    def display_template_handler(self, change: dict[str, str | None]) -> None:
        """Display the template and the corresponding evaluation results."""
        if change["new"] is None:
            return
        df_index = int(change["new"].split(" ")[1])
        self.display_eval_results(df_index)

    def display_run_handler(self, change: dict[str, str | None]) -> None:
        """Display the run and the corresponding templates."""
        if change["new"] is None:
            return
        path = change["new"]
        self.display_run(path)

    def display_run(self, path: str) -> None:
        """Display the results of a VAPO run."""
        self.run_dropdown.disabled = True
        filename = f"{path}/eval_results.json"
        eval_results = json.loads(read_file_from_gcs(filename))
        filename = f"{path}/templates.json"
        templates = json.loads(read_file_from_gcs(filename))

        if len(templates) == len(eval_results):
            offset = 0
        elif len(templates) == len(eval_results) + 1:
            # In some setups it is possible to have 1 more template than results.
            offset = 1
        else:
            raise ValueError(
                "Number of templates doesn't match number of eval results"
                f" {len(templates)} vs {len(eval_results)}"
            )

        self.templates = [
            pd.json_normalize(template) for template in templates[offset:]
        ]
        metric_columns = [col for col in self.templates[0].columns if "metric" in col]
        self.eval_results = [
            process_results(pd.read_json(io.StringIO(result["metrics_table"])))
            for result in eval_results
        ]

        options = []
        for i, template in enumerate(self.templates):
            metrics = []
            for col in metric_columns:
                value = template[col].tolist()[0]
                short_col = extract_metric_name(col)
                metrics.append(f"{short_col}: {value}")
            metrics_str = " ".join(metrics)
            options.append(f"Template {i} {metrics_str}")

        self.template_dropdown.disabled = False
        self.template_dropdown.options = options
        self.run_dropdown.disabled = False

    def display_eval_results(self, index: int) -> None:
        """Display the evaluation results for a specific template."""
        with self.results_output:
            self.results_output.clear_output(wait=True)  # Clear previous output
            display_dataframe(self.templates[index])
            print()
            display_dataframe(self.eval_results[index])

    def get_container(self) -> widgets.VBox:
        """Get the container widget for the results UI."""
        return widgets.VBox(
            [
                self.run_label,
                self.run_dropdown,
                self.dropdown_description,
                self.template_dropdown,
                self.results_output,
            ]
        )


def get_id(length: int = 8) -> str:
    """Generate a random id of a specified length (default=8)."""
    return "".join(random.choices(string.ascii_lowercase + string.digits, k=length))


def get_auth_token() -> str:
    """A function to collect the authorization token."""
    result = subprocess.run(
        ["gcloud", "auth", "print-identity-token", "-q"],
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.strip()


def init_new_model(
    model_name: str,
    generation_config: GenerationConfig | None = None,
    safety_settings: list[SafetySetting] | None = None,
    **kwargs: Any,
) -> GenerativeModel:
    """Initialize a new model with configurable generation and safety settings."""
    if generation_config is None:
        generation_config = GenerationConfig(
            candidate_count=1, max_output_tokens=2048, temperature=0
        )
    if safety_settings is None:
        safety_settings = [
            generative_models.SafetySetting(
                category=generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                method=generative_models.SafetySetting.HarmBlockMethod.SEVERITY,
                threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
            ),
            generative_models.SafetySetting(
                category=generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                method=generative_models.SafetySetting.HarmBlockMethod.SEVERITY,
                threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
            ),
            generative_models.SafetySetting(
                category=generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
                method=generative_models.SafetySetting.HarmBlockMethod.SEVERITY,
                threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
            ),
            generative_models.SafetySetting(
                category=generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,
                method=generative_models.SafetySetting.HarmBlockMethod.SEVERITY,
                threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
            ),
        ]
    model = GenerativeModel(
        model_name=model_name,
        generation_config=generation_config,
        safety_settings=safety_settings,
        **kwargs,
    )
    return model


@retry(wait=wait_random_exponential(multiplier=1, max=120))
async def async_generate(
    prompt: str,
    model: GenerativeModel,
    function_handler: dict[str, Callable] | None = None,
    tools: Tool | None = None,
    tool_config: ToolConfig | None = None,
    **kwargs: Any,
) -> Union[str, None]:
    """Generates a response from the model, optionally handling function calls."""
    user_prompt_content = Content(role="user", parts=[Part.from_text(prompt)])
    try:
        # Initial generation - potentially calling a function.
        response = await model.generate_content_async(
            prompt,
            tools=[tools] if tools else None,  # Only provide tools if they exist
            tool_config=tool_config if tool_config else None,  # Same for tool_config
            **kwargs,
        )

        # Handle function calls if applicable
        if (
            function_handler
            and response
            and response.candidates
            and response.candidates[0].content.parts[0].function_call
        ):
            while response.candidates[0].content.parts[0].function_call:
                function_call = response.candidates[0].content.parts[0].function_call
                function_name = function_call.name
                if function_name in function_handler:
                    function_args = function_call.args  # No need for manual conversion
                    api_response = function_handler[function_name](function_args)
                    response = await model.generate_content_async(
                        [
                            user_prompt_content,
                            response.candidates[0].content,
                            Content(
                                parts=[
                                    Part.from_function_response(
                                        name=function_name,
                                        response={"content": api_response},
                                    )
                                ]
                            ),
                        ],
                        tools=[tools] if tools else None,  # Conditional tool passing
                        tool_config=tool_config if tool_config else None,
                    )
                else:
                    break  # Exit loop if function not found

        # Extract and return text if generation was successful
        if response and response.candidates and response.candidates[0].content.parts:
            return (
                response.candidates[0].content.parts[0].text
            )  # More robust text extraction
        return None
    except Exception as e:  # pylint: disable=broad-except
        print(f"Error calling the model: {e}")  # Include the actual error message
        return "Could not call the model. Please try it again in a few minutes."


# pylint: disable=too-many-positional-arguments,too-many-arguments
def evaluate_task(
    df: pd.DataFrame,
    prompt_col: str,
    reference_col: str,
    response_col: str,
    experiment_name: str,
    eval_metrics: list[str],
    eval_sample_n: int,
) -> dict[str, float]:
    """Evaluate task using Vertex AI Evaluation."""
    # Generate a unique id for the experiment run
    idx = get_id()

    # Rename the columns to match the expected format
    eval_dataset = df[[prompt_col, reference_col, response_col]].rename(
        columns={
            prompt_col: "prompt",
            reference_col: "reference",
            response_col: "response",
        }
    )

    # Drop rows with missing values
    eval_dataset = eval_dataset.dropna()

    # Sample a subset of the dataset
    eval_dataset = eval_dataset.sample(n=eval_sample_n, random_state=8).reset_index(
        drop=True
    )

    # Create an EvalTask object
    eval_task = EvalTask(
        dataset=eval_dataset,
        metrics=eval_metrics,
        experiment=experiment_name,
    )

    # Evaluate the task
    result = eval_task.evaluate(experiment_run_name=f"{experiment_name}-{idx}")

    # Return the summary metrics
    return result.summary_metrics


def print_df_rows(
    df: pd.DataFrame, columns: list[str] | None = None, n: int = 3
) -> None:
    """Print a subset of rows from a DataFrame."""
    # Apply column filtering if specified
    if columns:
        df = df[columns]

    # Style definitions for improved readability
    base_style = (
        "font-family: monospace; font-size: 14px; white-space: pre-wrap; width:"
        " auto; overflow-x: auto;"
    )
    header_style = base_style + "font-weight: bold;"

    # Iterate through the specified number of rows
    for _, row in df.head(n).iterrows():
        # Display each column name as a bold header
        for column in df.columns:
            display(
                HTML(
                    "<span"
                    f" style='{header_style}'>{column.replace('_', ' ').title()}:"
                    " </span>"
                )
            )
            display(
                HTML(f"<span style='{base_style}'>{row[column]}</span><br>")
            )  # Display value and line break
        display(HTML("<hr>"))  # Add separator between rows for clarity


def plot_eval_metrics(
    eval_results: list[tuple[str, dict[str, float]]],
    metrics: list[str] | None = None,
) -> None:
    """Plot a bar plot for the evaluation results."""
    # Create data for the bar plot
    data = []
    for eval_result in eval_results:
        title, summary_metrics = eval_result
        if metrics:
            summary_metrics = {
                k: summary_metrics[k]
                for k, v in summary_metrics.items()
                if any(selected_metric in k for selected_metric in metrics)
            }
        summary_metrics = {k: v for k, v in summary_metrics.items() if "mean" in k}
        data.append(
            go.Bar(
                x=list(summary_metrics.keys()),
                y=list(summary_metrics.values()),
                name=title,
            )
        )

    # Update the figure with the data
    fig = go.Figure(data=data)

    # Add the title
    fig.update_layout(
        title=go.layout.Title(text="Evaluation Metrics", x=0.5),
        xaxis_title="Metric Name",
        yaxis_title="Mean Value",
    )

    # Change the bar mode
    fig.update_layout(barmode="group")

    # Show the plot
    fig.show()


def create_target_column(row: dict[str, Any]) -> str:
    """Creates a JSON string representing tool calls from input row."""
    tool_calls = (
        [{"name": row["tool_names"], "arguments": row["tool_arguments"]}]
        if row.get("tool_names")
        else []
    )
    return json.dumps({"content": "", "tool_calls": tool_calls})


def tool_config_to_dict(tool_config: ToolConfig | None) -> dict[str, Any] | None:
    """Converts a ToolConfig object to a dictionary."""
    if tool_config is None:
        return None
    # pylint: disable=protected-access
    config = tool_config._gapic_tool_config.function_calling_config
    return {
        "function_calling_config": {
            "mode": config.mode.name,
            "allowed_function_names": list(config.allowed_function_names),
        }
    }


def replace_type_key(data: dict[str, Any]) -> dict[str, Any]:
    """Recursively replaces "type_" with "type" in a dictionary or list."""

    def _recursive_replace(item: Any) -> Any:
        if isinstance(item, dict):
            return {
                ("type" if k == "type_" else k): _recursive_replace(v)
                for k, v in item.items()
            }
        elif isinstance(item, list):
            return [_recursive_replace(elem) for elem in item]
        else:
            return item

    new_data = {}
    for key, value in data.items():
        if key == "function_declarations" and isinstance(value, list):
            new_data[key] = [_recursive_replace(tool) for tool in value]
        else:
            new_data[key] = value
    return new_data


def validate_tools(spec: str) -> None:
    """Validates the tools specification."""
    # Define the JSON schema for validation
    schema = {
        "type": "object",
        "properties": {
            "tools": {
                "type": "array",
                "minItems": 1,  # Ensures that 'tools' is not an empty array
                "items": {
                    "type": "object",
                    "properties": {
                        "function_declarations": {
                            "type": "array",
                            # Ensures this is not an empty array
                            "minItems": 1,
                            "items": {
                                "type": "object",
                                "properties": {
                                    "name": {"type": "string"},
                                    "description": {"type": "string"},
                                    "parameters": {
                                        "type": "object",
                                        "properties": {
                                            "type": {"type": "string"},
                                            "properties": {"type": "object"},
                                            "required": {
                                                "type": "array",
                                                "items": {"type": "string"},
                                            },
                                        },
                                        "required": ["type", "properties"],
                                    },
                                },
                                "required": ["name", "description", "parameters"],
                            },
                        }
                    },
                    "required": ["function_declarations"],
                },
            }
        },
        "required": ["tools"],
    }

    json_spec = json.loads(spec)
    try:
        # Validate the JSON specification against the schema
        validate(instance=json_spec, schema=schema)
    except ValidationError as e:
        raise ValueError(f"Invalid Tools specification: {e}") from e


def validate_tool_config(tool_config: str) -> None:
    """Validates the format of the tool_config."""
    schema = {
        "type": "object",
        "properties": {
            "function_calling_config": {
                "type": "object",
                "properties": {
                    "mode": {"type": "string", "enum": ["AUTO", "ANY", "NONE"]},
                    "allowed_function_names": {
                        "type": "array",
                        "items": {"type": "string"},
                    },
                },
                "required": ["mode"],
            }
        },
        "required": ["function_calling_config"],
    }
    try:
        validate(instance=json.loads(tool_config), schema=schema)
    except ValidationError as e:
        raise ValueError(f"Invalid tool_config: {tool_config}") from e