# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=too-many-lines

"""Utility functions and classes for the VAPO notebook."""
import csv
import io
import json
import random
import re
import string
import subprocess
from typing import Any, Callable, Union

from IPython.core.display import DisplayHandle
from IPython.display import HTML, display
from google.cloud import aiplatform, storage
import ipywidgets as widgets
import jinja2
import jinja2.meta
from jsonschema import ValidationError, validate
import pandas as pd
import plotly.graph_objects as go
from tenacity import retry, wait_random_exponential
from tensorflow.io import gfile
from vertexai import generative_models
from vertexai.evaluation import EvalTask
from vertexai.generative_models import (
    Content,
    GenerationConfig,
    GenerativeModel,
    Part,
    SafetySetting,
    Tool,
    ToolConfig,
)


def is_target_required_metric(eval_metric: str) -> bool:
    """Tell whether `eval_metric` is one of the metrics that need a target label.

    Args:
        eval_metric: Name of the evaluation metric to look up.

    Returns:
        True if the name appears in the list of target-dependent metrics below,
        False otherwise.
    """
    target_dependent_metrics = (
        "bleu",
        "exact_match",
        "question_answering_correctness",
        "rouge_1",
        "rouge_2",
        "rouge_l",
        "rouge_l_sum",
        "tool_call_valid",
        "tool_name_match",
        "tool_parameter_key_match",
        "tool_parameter_kv_match",
    )
    return eval_metric in target_dependent_metrics


def is_run_target_required(eval_metric_types: list[str], source_model: str) -> bool:
    """Decide whether a run needs target labels in its dataset.

    Args:
        eval_metric_types: Names of the evaluation metrics used by the run.
        source_model: Source model name; any truthy value means no label is
            required, regardless of the metrics.

    Returns:
        True only when no source model is given and at least one metric is
        target-dependent according to `is_target_required_metric`.
    """
    # A configured source model short-circuits the metric check entirely.
    if source_model:
        return False

    return any(is_target_required_metric(name) for name in eval_metric_types)


# Key under which a dataset example is expected to carry its target label;
# checked by validate_prompt_and_data when labels are enforced.
_TARGET_KEY = "target"


def load_file_from_gcs(dataset: str) -> str:
    """Read a `gs://` object in text mode and return its whole content.

    Args:
        dataset: Path of the file; must start with `gs://`.

    Returns:
        The file content as a string.

    Raises:
        ValueError: If `dataset` does not start with `gs://`.
    """
    # Guard clause: reject every non-GCS location before touching the filesystem.
    if not dataset.startswith("gs://"):
        raise ValueError(
            "Unsupported file location. Only GCS paths starting with 'gs://' are"
            " supported."
        )

    with gfile.GFile(dataset, "r") as gcs_file:
        return gcs_file.read()


def parse_jsonl(data_str: str) -> list[dict[str, str]]:
    """Parse the content of a JSONL file into a list of decoded records.

    Records are separated by `\\n`, `\\r\\n` or `\\r`. Empty and whitespace-only
    lines are skipped.

    Args:
        data_str: Full text of the JSONL file.

    Returns:
        One decoded JSON value per non-blank line, in file order.

    Raises:
        ValueError: If a non-blank line is not valid JSON. The offending line and
            the decoder error are included in the message.
    """
    records = []
    # Split on real newline sequences only. str.splitlines() would also break on
    # characters such as U+2028/U+2029/U+0085, which may appear unescaped inside a
    # JSON string and would cut a valid record in half. Raw CR/LF can never occur
    # inside a JSON string, so splitting on them is always safe.
    for line in re.split(r"\r\n?|\n", data_str):
        # A whitespace-only line carries no record.
        if not line.strip():
            continue
        try:
            records.append(json.loads(line))
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Error decoding JSON on line: {line}. Error: {e}"
            ) from e
    return records


def parse_and_validate_csv(data_str: str) -> list[dict[str, str]]:
    """Parse CSV text into a list of dictionaries keyed by the header row.

    Completely blank rows (for example a trailing empty line) are skipped, in
    line with how parse_jsonl ignores blank lines.

    Args:
        data_str: Full text of the CSV file; the first row is the header.

    Returns:
        One dictionary per data row, mapping header names to cell strings.

    Raises:
        ValueError: If the file is empty, the header row is empty, or a
            non-blank row has a different number of fields than the header.
    """
    csv_reader = csv.reader(io.StringIO(data_str))

    # The first row supplies the keys for every record.
    try:
        headers = next(csv_reader)
    except StopIteration as e:
        raise ValueError("The CSV file is empty.") from e
    if not headers:
        raise ValueError("The CSV file has an empty or invalid header row.")

    data = []
    # Numbering starts at 2 because row 1 is the header; blank rows still count
    # so the reported number matches the row's position in the file.
    for row_number, row in enumerate(csv_reader, start=2):
        # csv.reader yields an empty list for a completely blank line.
        if not row:
            continue
        if len(row) != len(headers):
            raise ValueError(
                f"Row {row_number} has an inconsistent number of fields. "
                f"Expected {len(headers)} fields but found {len(row)}."
            )
        data.append(dict(zip(headers, row)))

    return data


def load_dataset(dataset: str) -> list[dict[str, str]]:
    """Fetch a dataset file and parse it according to its extension.

    Args:
        dataset: Path of the dataset; it is read with `load_file_from_gcs` and
            must end in `.jsonl` or `.csv`.

    Returns:
        The parsed examples as a list of dictionaries.

    Raises:
        ValueError: If the extension is neither `.jsonl` nor `.csv` (raised after
            the file has been read), or if loading/parsing itself fails.
    """
    raw_text = load_file_from_gcs(dataset)

    # Extension -> parser table, checked in the same order as the file types
    # are advertised in the error message.
    parsers_by_suffix = (
        (".jsonl", parse_jsonl),
        (".csv", parse_and_validate_csv),
    )
    for suffix, parser in parsers_by_suffix:
        if dataset.endswith(suffix):
            return parser(raw_text)

    raise ValueError(
        "Unsupported file type. Please provide a file with '.jsonl' or '.csv'"
        " extension."
    )


def validate_prompt_and_data(
    template: str,
    dataset_path: str,
    placeholder_to_content: str,
    label_enforced: bool,
) -> None:
    """Check that a prompt template and a dataset are compatible.

    Args:
        template: Prompt template. Single-brace placeholders such as `{name}` are
            rewritten to the Jinja form `{{name}}` before parsing.
        dataset_path: Location of the dataset, loaded with `load_dataset`.
        placeholder_to_content: JSON object (as a string) whose entries are
            merged into every example before the checks run.
        label_enforced: When True, every example must hold a non-empty value
            under the target key.

    Raises:
        ValueError: If the template cannot be parsed, an example lacks a template
            variable, or (with `label_enforced`) the target is missing or empty.
        Warning: Raised as an exception when the examples contain keys that the
            template does not use.
    """
    examples = load_dataset(dataset_path)
    static_values = json.loads(placeholder_to_content)

    # Promote lone `{var}` placeholders to Jinja's `{{var}}` syntax.
    template = re.sub(r"(?<!{){(?!{)(\s*\w+\s*)(?<!})}(?!})", r"{{\1}}", template)

    try:
        syntax_tree = jinja2.Environment().parse(template)
    except jinja2.exceptions.TemplateSyntaxError as e:
        raise ValueError(f"Invalid template: {template}") from e
    template_variables = jinja2.meta.find_undeclared_variables(syntax_tree)

    unused_keys = set()
    for example in examples:
        # The fixed placeholder values become part of the example itself.
        example.update(static_values)

        if label_enforced:
            if _TARGET_KEY not in example:
                raise ValueError(
                    f"The example {example} doesn't have a key corresponding to the target"
                    f" var: {_TARGET_KEY}"
                )
            if not example[_TARGET_KEY]:
                raise ValueError(
                    f"The following example has an empty target: {example}"
                )

        absent_keys = [key for key in template_variables if key not in example]
        if absent_keys:
            raise ValueError(
                f"The example {example} doesn't have a key corresponding to following"
                f" template vars: {absent_keys}"
            )

        unused_keys |= {key for key in example if key not in template_variables}

    # Extra keys are reported only after every example passed the hard checks.
    if unused_keys:
        raise Warning(
            "Warning: extra keys in the examples not used in the prompt template"
            f" template {unused_keys}"
        )


def run_custom_job(
    display_name: str,
    container_uri: str,
    container_args: dict[str, str],
) -> "aiplatform.CustomJob":
    """Create and submit a single-replica custom job running a container.

    Args:
        display_name: Display name given to the custom job.
        container_uri: URI of the container image to run.
        container_args: Mapping turned into `--key=value` container arguments.

    Returns:
        The submitted `aiplatform.CustomJob` object (not a string), so callers
        can inspect its state later.
    """
    # One worker pool with a single replica on a fixed machine type.
    worker_pool_specs = [
        {
            "replica_count": 1,
            "container_spec": {
                "image_uri": container_uri,
                "args": [f"--{k}={v}" for k, v in container_args.items()],
            },
            "machine_spec": {
                "machine_type": "n1-standard-4",
            },
        }
    ]

    custom_job = aiplatform.CustomJob(
        display_name=display_name,
        worker_pool_specs=worker_pool_specs,
    )
    custom_job.submit()
    return custom_job


def run_apd(
    config: dict[str, str], bucket_uri: str, display_name: str
) -> "aiplatform.CustomJob":
    """Launch the Vertex prompt optimizer as a custom job.

    The configuration is written as JSON to
    `{bucket_uri}/{display_name}/input_config.json`, the Vertex AI SDK is
    initialized from the config, and the optimizer container is submitted with
    the path of that file as its `--config` argument.

    Args:
        config: Optimizer configuration; must contain the `project` and
            `target_model_location` keys used to initialize the SDK.
        bucket_uri: Base URI under which the config file and staging data go.
        display_name: Name of the job, also used as the sub-directory name.

    Returns:
        The submitted custom job as returned by `run_custom_job` (not a string).
    """
    print(f"\n\nJob display name: {display_name}")
    version = "preview_v1_0"
    container_uri = "us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/apd"
    config_path = f"{bucket_uri}/{display_name}/input_config.json"

    # The container receives the path of this file rather than the config itself.
    with gfile.GFile(config_path, "w") as f:
        json.dump(config, f)

    aiplatform.init(
        project=config["project"],
        location=config["target_model_location"],
        staging_bucket=f"{bucket_uri}/{display_name}",
    )

    return run_custom_job(
        display_name=display_name,
        container_uri=f"{container_uri}:{version}",
        container_args={"config": config_path},
    )


def update_best_display(
    df: pd.DataFrame,
    textarea: widgets.Textarea,
    best_score_label: widgets.Label,
    eval_metric: str,
) -> None:
    """Show the best-scoring prompt and its score in the given widgets.

    Args:
        df: Templates table with a `prompt` column and a
            `metrics.{eval_metric}/mean` column. A `score` column holding a copy
            of that metric is added to it in place.
        textarea: Widget whose `value` receives the best prompt text.
        best_score_label: Widget whose `value` receives the score summary.
        eval_metric: Metric name used to pick the score column.
    """
    df["score"] = df[f"metrics.{eval_metric}/mean"]

    # idxmax() yields an index *label*, which is what .loc expects; argmax()
    # yields a position and only matches labels on a default RangeIndex.
    best_label = df["score"].idxmax()
    best_template = df.loc[best_label, "prompt"]
    best_score = df.loc[best_label, "score"]
    # The first row is treated as the baseline the improvement is measured from.
    original_score = df["score"].iloc[0]

    # Display the plain `llm()` call instead of the answer-storing wrapper.
    best_template = best_template.replace("store('answer', llm())", "llm()")
    textarea.value = best_template

    improvement = best_score - original_score
    no_improvement_str = "\nNo better template is found yet." if not improvement else ""
    best_score_label.value = (
        f"Score: {best_score}" f" Improvement: {improvement: .3f} {no_improvement_str}"
    )


def generate_dataframe(filename: str) -> pd.DataFrame:
    """Load a JSON file and flatten it into a pandas dataframe.

    Args:
        filename: Path of the JSON file, opened through `gfile`.

    Returns:
        The normalized dataframe, or an empty dataframe when the file does not
        exist or does not contain valid JSON.
    """
    if not gfile.exists(filename):
        return pd.DataFrame()

    try:
        with gfile.GFile(filename, "r") as json_file:
            parsed = json.load(json_file)
    except json.JSONDecodeError:
        # Undecodable content is treated the same as a missing file.
        return pd.DataFrame()

    return pd.json_normalize(parsed)


def left_aligned_df_html(df: pd.DataFrame) -> HTML:
    """Render a dataframe as an HTML table whose cells are left-aligned.

    Args:
        df: Dataframe to render; the index is not included in the table.

    Returns:
        An `HTML` object holding a `<style>` rule followed by the table markup.
    """
    # The CSS class set here is the one targeted by the style rule below.
    table_markup = df.to_html(index=False, classes="left-aligned")

    # `!important` lets the rule win over the notebook's default alignment.
    return HTML(
        f"""
    <style>
        .left-aligned td, .left-aligned th {{ text-align: left !important; }}
    </style>
    {table_markup}
    """
    )


def extract_top_level_function_name(source_code: str) -> str | None:
    """Extract the top level function name from the source code."""
    match = re.search(r"^def\s+([a-zA-Z_]\w*)\s*\(", source_code, re.MULTILINE)
    if match:
        return match.group(1)
    return None


class ProgressForm:
    """A class to display the progress of the optimization job.

    On construction it renders, in order, a job-state line, a status line and,
    depending on `optimization_mode`, one progress section for instruction
    optimization and/or one for demonstration optimization. `monitor_progress`
    is then meant to be called repeatedly to refresh those displays.
    """

    # pylint: disable=too-many-instance-attributes

    def __init__(self, params: dict[str, str]) -> None:
        """Initialize the progress form.

        Args:
            params: Job parameters. The keys read here are `optimization_mode`,
                `num_steps`, `num_demo_set_candidates`, `eval_metrics_types`
                and `output_path`.
        """
        # UI handles stay None for an optimization mode that is not enabled;
        # update_progress treats a None progress bar as "nothing to update".
        self.instruction_progress_bar = None
        self.instruction_display = None
        self.instruction_best = None
        self.instruction_score = None
        self.demo_progress_bar = None
        self.demo_display = None
        self.demo_best = None
        self.demo_score = None

        # display_id=True gives back handles that are updated in place later.
        self.job_state_display = display(
            HTML("<span>Job State: Not Started!</span>"), display_id=True
        )
        self.status_display = display(HTML(""), display_id=True)

        if params["optimization_mode"] in ["instruction", "instruction_and_demo"]:
            (
                self.instruction_progress_bar,
                self.instruction_display,
                self.instruction_best,
                self.instruction_score,
            ) = self.create_progress_ui("Instruction", params["num_steps"])

        if params["optimization_mode"] in ["demonstration", "instruction_and_demo"]:
            (
                self.demo_progress_bar,
                self.demo_display,
                self.demo_best,
                self.demo_score,
            ) = self.create_progress_ui(
                "Demonstration", params["num_demo_set_candidates"]
            )

        # With several metrics the score column is looked up under the
        # "composite_metric" name instead of a single metric name.
        if len(params["eval_metrics_types"]) == 1:
            self.eval_metric = params["eval_metrics_types"][0]
        else:
            self.eval_metric = "composite_metric"

        self.output_path = params["output_path"]
        # Last dataframes seen per mode; None until the first monitor_progress call.
        self.instruction_df = None
        self.demo_df = None

    # pylint: disable=too-many-positional-arguments,too-many-arguments
    def update_progress(
        self,
        progress_bar: widgets.IntProgress | None,
        templates_file: str,
        df: pd.DataFrame | None,
        df_display: DisplayHandle,
        best_textarea: widgets.Textarea,
        best_score: widgets.Label,
        eval_metric: str,
    ) -> pd.DataFrame:
        """Update the progress of the optimization job.

        Args:
            progress_bar: Progress bar of the mode, or None if the mode is off.
            templates_file: Path of the templates JSON file to reload.
            df: Dataframe returned by the previous call, or None on first use.
            df_display: Display handle showing the templates table.
            best_textarea: Widget showing the best template.
            best_score: Widget showing the best score summary.
            eval_metric: Metric name used to pick the score column.

        Returns:
            The freshly loaded dataframe, or an empty dataframe when
            `progress_bar` or `df` is None (nothing is loaded in that case).
        """

        def get_last_step(df: pd.DataFrame) -> int:
            """Return the largest value in the `step` column, or -1 if empty."""
            if df.empty:
                return -1
            return int(df["step"].max())

        # First call (df is None) or disabled mode: only hand back an empty
        # dataframe so the next call has something to compare against.
        if progress_bar is None or df is None:
            return pd.DataFrame()

        new_df = generate_dataframe(templates_file)

        last_step = get_last_step(df)
        new_last_step = get_last_step(new_df)
        # Refresh the widgets only when the file contains a later step.
        if new_last_step > last_step:
            df_display.update(left_aligned_df_html(new_df))
            update_best_display(new_df, best_textarea, best_score, eval_metric)
            # Advance the bar by the number of newly seen steps.
            progress_bar.value = progress_bar.value + new_last_step - last_step

        return new_df

    def create_progress_ui(
        self, opt_mode: str, num_opt_steps: str
    ) -> tuple[widgets.IntProgress, DisplayHandle, widgets.Textarea, widgets.Label]:
        """Create the progress UI for a specific optimization mode.

        Args:
            opt_mode: Human-readable mode name printed as the section title.
            num_opt_steps: Number of steps; converted with `int()` and used as
                the maximum of the progress bar.

        Returns:
            The progress bar, the templates display handle, the best-template
            text area and the best-score label, in that order.
        """
        # Widgets are displayed in creation order, which fixes the cell layout.
        print(f"\n\n{opt_mode} Optimization")
        progress_bar = widgets.IntProgress(
            value=0, min=0, max=int(num_opt_steps), step=1, description="Progress"
        )
        display(progress_bar)
        print("\nGenerated Templates:")
        templates_display = display("No template is evaluated yet!", display_id=True)

        print("\nBest Template so far:")
        best_textarea = widgets.Textarea(
            value="NA",
            disabled=False,
            layout=widgets.Layout(width="80%", height="150px"),
        )
        display(best_textarea)

        best_score = widgets.Label(value="Score: NA Improvement: NA")
        display(best_score)

        return progress_bar, templates_display, best_textarea, best_score

    def monitor_progress(self, job: aiplatform.CustomJob) -> bool:
        """Monitor the progress of the optimization job.

        Args:
            job: The custom job being monitored.

        Returns:
            True while the job is not done (progress widgets were refreshed),
            False once it is done (a final success or error status was shown).
        """
        self.job_state_display.update(HTML(f"<span>Job State: {job.state.name}</span>"))

        # Per-mode template files read from under the output path.
        instruction_templates_file = f"{self.output_path}/instruction/templates.json"
        demo_templates_file = f"{self.output_path}/demonstration/templates.json"

        if not job.done():
            self.instruction_df = self.update_progress(
                self.instruction_progress_bar,
                instruction_templates_file,
                self.instruction_df,
                self.instruction_display,
                self.instruction_best,
                self.instruction_score,
                self.eval_metric,
            )
            self.demo_df = self.update_progress(
                self.demo_progress_bar,
                demo_templates_file,
                self.demo_df,
                self.demo_display,
                self.demo_best,
                self.demo_score,
                self.eval_metric,
            )
            return True

        # NOTE(review): the progress widgets are not refreshed once the job is
        # done, so templates written after the last poll are not shown — confirm
        # this is intended.
        if job.state.name != "JOB_STATE_SUCCEEDED":
            errors = [f"Error: Job failed with error {job.error}."]
            # Collect details from whichever error files exist.
            for err_file in [
                f"{self.output_path}/instruction/error.json",
                f"{self.output_path}/demonstration/error.json",
                f"{self.output_path}/error.json",
            ]:
                if gfile.exists(err_file):
                    with gfile.GFile(err_file, "r") as f:
                        error_json = json.load(f)
                    errors.append(f"Detailed error: {error_json['Error']}")
                    errors.append(
                        f"Please feel free to send {err_file} to the VAPO team to help"
                        " resolving the issue."
                    )

            errors.append(
                "All the templates found before failure can be found under"
                f" {self.output_path}"
            )
            errors.append(
                "Please consider rerunning to make sure the failure is intransient."
            )
            err = "\n".join(errors)
            # Newlines become <br> so each message is rendered on its own line.
            err = err.replace("\n", "<br>")
            self.status_display.update(HTML(f'<span style="color: red;">{err}</span>'))
        else:
            self.status_display.update(
                HTML(
                    '<span style="color: green;">Job succeeded!</span> <span>All the'
                    f" artifacts can be found under {self.output_path}</span>"
                )
            )
        return False


def display_dataframe(df: pd.DataFrame) -> None:
    """Show a dataframe in the notebook with each cell wrapped in a div.

    Args:
        df: Dataframe to render through its Styler and display as HTML.
    """

    def as_scrollable_cell(cell: str) -> str:
        """Wrap one cell value in a `<div class="scrollable">` element."""
        return f'<div class="scrollable">{cell}</div>'

    # The Styler applies the formatter to every cell before rendering.
    styler = df.style.format(as_scrollable_cell)
    table_markup = styler.to_html(index=False)

    display(HTML(table_markup))


def split_gcs_path(gcs_path: str) -> tuple[str, str]:
    """Break a `gs://bucket/prefix` path into its bucket and prefix parts.

    Args:
        gcs_path: Full GCS path starting with `gs://`.

    Returns:
        A `(bucket_name, prefix)` tuple; the prefix is an empty string when the
        path has nothing after the bucket name.

    Raises:
        ValueError: If the path does not start with `gs://`.
    """
    scheme = "gs://"
    if not gcs_path.startswith(scheme):
        raise ValueError("Invalid GCS path. Must start with 'gs://'")

    # Everything up to the first "/" is the bucket; the remainder is the prefix.
    bucket_name, _, prefix = gcs_path[len(scheme) :].partition("/")
    return bucket_name, prefix


def list_gcs_objects(full_gcs_path: str) -> list[str]:
    """Return the names of all objects whose name starts with the path's prefix.

    Args:
        full_gcs_path: `gs://bucket/prefix` path to list.

    Returns:
        The blob names (without the `gs://bucket/` part) matching the prefix.
    """
    bucket_name, prefix = split_gcs_path(full_gcs_path)

    client = storage.Client()
    # Prefix matching is by name, so this also covers nested "directories".
    matching_blobs = client.bucket(bucket_name).list_blobs(prefix=prefix)

    object_names = []
    for blob in matching_blobs:
        object_names.append(blob.name)
    return object_names


def find_directories_with_files(
    full_gcs_path: str, required_files: list[str]
) -> list[str]:
    """Locate directories under a GCS path that contain every required file.

    Args:
        full_gcs_path: `gs://bucket/prefix` path to search under.
        required_files: File names that must all be present directly inside a
            directory for it to be reported.

    Returns:
        Full `gs://` paths of the matching directories, in no particular order.
    """
    bucket_name, prefix = split_gcs_path(full_gcs_path)
    object_names = list_gcs_objects(f"gs://{bucket_name}/{prefix}")

    # Group file names by their parent directory.
    files_by_directory: dict[str, set[str]] = {}
    for object_name in object_names:
        parent, _, leaf = object_name.rpartition("/")
        # Objects sitting at the bucket root have no parent directory.
        if not parent:
            continue
        files_by_directory.setdefault(parent, set()).add(leaf)

    matching = {
        f"gs://{bucket_name}/{parent}"
        for parent, present in files_by_directory.items()
        if all(name in present for name in required_files)
    }
    return list(matching)


def extract_metric_name(metric_string: str) -> str:
    """Pull the short metric name out of a column name.

    Args:
        metric_string: Text such as a column name of the form
            `<something>.<name>/<suffix>`.

    Returns:
        The word found between a `.` and the following `/`, or the input
        unchanged when no such pattern occurs.
    """
    found = re.search(r"\.(\w+)/", metric_string)
    if found is None:
        # Nothing to shorten: fall back to the full string.
        return metric_string
    return found.group(1)


def read_file_from_gcs(filename: str) -> str:
    """Return the full text content of a file opened through `gfile`.

    Args:
        filename: Path of the file to read.

    Returns:
        The file content as a string.
    """
    with gfile.GFile(filename, "r") as gcs_file:
        content = gcs_file.read()
    return content


def process_results(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of the results without the columns hidden from users.

    Dropped columns are those whose name contains `confidence` or
    `raw_eval_resp`, plus the columns named exactly `instruction` or `context`.

    Args:
        df: Evaluation results table.

    Returns:
        A new dataframe without those columns.
    """
    hidden_columns = [
        name
        for name in df.columns
        if "confidence" in name
        or "raw_eval_resp" in name
        or name in ("instruction", "context")
    ]
    return df.drop(columns=hidden_columns)


class ResultsUI:
    """A UI to display the results of a VAPO run.

    It offers one dropdown to pick a run directory and one to pick a template
    of that run; selecting a template shows the template and its evaluation
    results in an output area.
    """

    def __init__(self, path: str) -> None:
        """Initialize the UI.

        Args:
            path: GCS path searched for run directories, i.e. directories that
                contain both `eval_results.json` and `templates.json`.
        """
        required_files = ["eval_results.json", "templates.json"]
        runs = find_directories_with_files(path, required_files)

        # NOTE(review): `runs[0]` below raises IndexError when no run directory
        # is found under `path` — confirm callers guarantee at least one run.
        self.run_label = widgets.Label("Select Run:")
        self.run_dropdown = widgets.Dropdown(
            options=runs, value=runs[0], layout=widgets.Layout(width="200px")
        )
        self.run_dropdown.observe(self.display_run_handler, names="value")

        # Create a label widget for the description
        self.dropdown_description = widgets.Label("Select Template:")
        # Starts empty and disabled; display_run fills in the options.
        self.template_dropdown = widgets.Dropdown(
            options=[],
            value=None,
            layout=widgets.Layout(width="400px"),
            disabled=True,
        )
        self.template_dropdown.observe(self.display_template_handler, names="value")
        self.results_output = widgets.Output(
            layout=widgets.Layout(
                height="600px", overflow="auto", margin="20px 0px 0px 0px"
            )
        )
        # Load the first run right away so the template dropdown is populated.
        self.display_run(runs[0])

    def display_template_handler(self, change: dict[str, str | None]) -> None:
        """Display the template and the corresponding evaluation results.

        Args:
            change: Change notification of the template dropdown; only its
                `new` entry is read.
        """
        if change["new"] is None:
            return

        # Options are built as "Template {i} ...", so the second
        # space-separated token is the template index.
        df_index = int(change["new"].split(" ")[1])
        self.display_eval_results(df_index)

    def display_run_handler(self, change: dict[str, str | None]) -> None:
        """Display the run and the corresponding templates.

        Args:
            change: Change notification of the run dropdown; its `new` entry is
                the selected run path.
        """
        if change["new"] is None:
            return

        path = change["new"]
        self.display_run(path)

    def display_run(self, path: str) -> None:
        """Display the results of a VAPO run.

        Loads `eval_results.json` and `templates.json` from `path`, stores them
        as dataframes on the instance and rebuilds the template dropdown.

        Args:
            path: Directory of the run to load.

        Raises:
            ValueError: If the number of templates is neither equal to nor one
                more than the number of evaluation results.
        """
        # Block run switching while this run is loading.
        # NOTE(review): if loading raises, the dropdown stays disabled.
        self.run_dropdown.disabled = True
        filename = f"{path}/eval_results.json"
        eval_results = json.loads(read_file_from_gcs(filename))

        filename = f"{path}/templates.json"
        templates = json.loads(read_file_from_gcs(filename))

        if len(templates) == len(eval_results):
            offset = 0
        elif len(templates) == len(eval_results) + 1:
            # In some setups it is possible to have 1 more template than results.
            offset = 1
        else:
            raise ValueError(
                "Number of templates doesn't match number of eval results"
                f" {len(templates)} vs {len(eval_results)}"
            )
        # Skip the leading template when there is one more template than
        # results, so both lists line up index by index.
        self.templates = [
            pd.json_normalize(template) for template in templates[offset:]
        ]
        # Metric columns are taken from the first template only.
        metric_columns = [col for col in self.templates[0].columns if "metric" in col]

        self.eval_results = [
            process_results(pd.read_json(io.StringIO(result["metrics_table"])))
            for result in eval_results
        ]
        # One dropdown entry per template: "Template {i} name: value ...".
        options = []
        for i, template in enumerate(self.templates):
            metrics = []
            for col in metric_columns:
                value = template[col].tolist()[0]
                short_col = extract_metric_name(col)
                metrics.append(f"{short_col}: {value}")
            metrics_str = " ".join(metrics)
            options.append(f"Template {i} {metrics_str}")

        self.template_dropdown.disabled = False
        self.template_dropdown.options = options
        self.run_dropdown.disabled = False

    def display_eval_results(self, index: int) -> None:
        """Display the evaluation results for a specific template.

        Args:
            index: Position of the template in `self.templates` and of its
                results in `self.eval_results`.
        """
        with self.results_output:
            self.results_output.clear_output(wait=True)  # Clear previous output
            display_dataframe(self.templates[index])
            print()
            display_dataframe(self.eval_results[index])

    def get_container(self) -> widgets.VBox:
        """Get the container widget for the results UI.

        Returns:
            A `VBox` stacking the run selector, the template selector and the
            results output area.
        """
        return widgets.VBox(
            [
                self.run_label,
                self.run_dropdown,
                self.dropdown_description,
                self.template_dropdown,
                self.results_output,
            ]
        )


def get_id(length: int = 8) -> str:
    """Build a random identifier of lowercase letters and digits.

    Args:
        length: Number of characters to generate (default 8).

    Returns:
        A string of `length` characters drawn with the `random` module.
    """
    alphabet = string.ascii_lowercase + string.digits
    picked = random.choices(alphabet, k=length)
    return "".join(picked)


def get_auth_token() -> str:
    """Fetch an identity token by running the `gcloud` command-line tool.

    Returns:
        The command's standard output with surrounding whitespace removed.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status (`check=True`).
    """
    command = ["gcloud", "auth", "print-identity-token", "-q"]
    completed = subprocess.run(
        command,
        capture_output=True,
        text=True,
        check=True,
    )
    token = completed.stdout.strip()
    return token


def init_new_model(
    model_name: str,
    generation_config: GenerationConfig | None = None,
    safety_settings: list[SafetySetting] | None = None,
    **kwargs: Any,
) -> GenerativeModel:
    """Create a `GenerativeModel` with default generation and safety settings.

    Args:
        model_name: Name of the model to instantiate.
        generation_config: Generation settings; when None, a config with one
            candidate, 2048 max output tokens and temperature 0 is used.
        safety_settings: Safety settings; when None, the four harm categories
            listed below are all set to `BLOCK_NONE` with the severity method.
        **kwargs: Extra keyword arguments forwarded to `GenerativeModel`.

    Returns:
        The configured model instance.
    """
    if generation_config is None:
        generation_config = GenerationConfig(
            candidate_count=1, max_output_tokens=2048, temperature=0
        )

    if safety_settings is None:
        # Same method and threshold for every category, so build them in a loop.
        harm_categories = (
            generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,
        )
        safety_settings = [
            generative_models.SafetySetting(
                category=harm_category,
                method=generative_models.SafetySetting.HarmBlockMethod.SEVERITY,
                threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
            )
            for harm_category in harm_categories
        ]

    return GenerativeModel(
        model_name=model_name,
        generation_config=generation_config,
        safety_settings=safety_settings,
        **kwargs,
    )


# NOTE(review): the function body catches every Exception and returns a fallback
# string, so this retry decorator presumably never sees a failure to retry on —
# confirm whether retries are actually intended.
@retry(wait=wait_random_exponential(multiplier=1, max=120))
async def async_generate(
    prompt: str,
    model: GenerativeModel,
    function_handler: dict[str, Callable] | None = None,
    tools: Tool | None = None,
    tool_config: ToolConfig | None = None,
    **kwargs: Any,
) -> Union[str, None]:
    """Generates a response from the model, optionally handling function calls.

    Args:
        prompt: Text prompt sent to the model.
        model: Model used for generation.
        function_handler: Maps function names to callables. When the model asks
            for a function that is in this mapping, the callable is invoked with
            the call's arguments and its result is sent back to the model.
        tools: Optional tool passed (wrapped in a list) to the model.
        tool_config: Optional tool configuration passed to the model.
        **kwargs: Extra keyword arguments for the first generation call only;
            they are not forwarded to the follow-up function-response calls.

    Returns:
        The text of the first part of the first candidate, None when the
        response has no candidates or parts, or a fixed fallback message when
        any exception is raised along the way.
    """

    # Kept so the original prompt can be replayed in function-response turns.
    user_prompt_content = Content(role="user", parts=[Part.from_text(prompt)])

    try:
        # Initial generation - potentially calling a function.
        response = await model.generate_content_async(
            prompt,
            tools=[tools] if tools else None,  # Only provide tools if they exist
            tool_config=tool_config if tool_config else None,  # Same for tool_config
            **kwargs,
        )

        # Handle function calls if applicable
        if (
            function_handler
            and response
            and response.candidates
            and response.candidates[0].content.parts[0].function_call
        ):
            # Keep answering function calls until the model stops asking for one.
            while response.candidates[0].content.parts[0].function_call:
                function_call = response.candidates[0].content.parts[0].function_call
                function_name = function_call.name

                if function_name in function_handler:
                    # The handler receives the call arguments as a single object.
                    function_args = function_call.args  # No need for manual conversion
                    api_response = function_handler[function_name](function_args)

                    # Replay the prompt and the model's call, then append the
                    # function result as a new content entry.
                    response = await model.generate_content_async(
                        [
                            user_prompt_content,
                            response.candidates[0].content,
                            Content(
                                parts=[
                                    Part.from_function_response(
                                        name=function_name,
                                        response={"content": api_response},
                                    )
                                ]
                            ),
                        ],
                        tools=[tools] if tools else None,  # Conditional tool passing
                        tool_config=tool_config if tool_config else None,
                    )
                else:
                    break  # Exit loop if function not found

        # Extract and return text if generation was successful
        if response and response.candidates and response.candidates[0].content.parts:
            return (
                response.candidates[0].content.parts[0].text
            )  # Only the first part of the first candidate is used
        return None

    except Exception as e:  # pylint: disable=broad-except
        # Best effort: report the error and return a fallback message instead of
        # propagating the exception.
        print(f"Error calling the model: {e}")  # Include the actual error message
        return "Could not call the model. Please try it again in a few minutes."


# pylint: disable=too-many-positional-arguments,too-many-arguments
def evaluate_task(
    df: pd.DataFrame,
    prompt_col: str,
    reference_col: str,
    response_col: str,
    experiment_name: str,
    eval_metrics: list[str],
    eval_sample_n: int,
) -> dict[str, float]:
    """Evaluate task using Vertex AI Evaluation.

    Args:
        df: Data holding the prompt, reference and response columns.
        prompt_col: Name of the column renamed to `prompt`.
        reference_col: Name of the column renamed to `reference`.
        response_col: Name of the column renamed to `response`.
        experiment_name: Experiment name; also the prefix of the run name.
        eval_metrics: Metrics passed to the `EvalTask`.
        eval_sample_n: Number of rows to sample. If fewer rows remain after
            dropping rows with missing values, all remaining rows are used.

    Returns:
        The summary metrics of the evaluation result.
    """

    # Generate a unique id for the experiment run
    idx = get_id()

    # Rename the columns to match the expected format
    eval_dataset = df[[prompt_col, reference_col, response_col]].rename(
        columns={
            prompt_col: "prompt",
            reference_col: "reference",
            response_col: "response",
        }
    )

    # Drop rows with missing values
    eval_dataset = eval_dataset.dropna()

    # Sampling without replacement raises when asked for more rows than exist,
    # which can happen once dropna() removed rows; cap the request instead.
    sample_size = min(eval_sample_n, len(eval_dataset))

    # Sample a subset of the dataset (fixed seed keeps the sample reproducible)
    eval_dataset = eval_dataset.sample(n=sample_size, random_state=8).reset_index(
        drop=True
    )

    # Create an EvalTask object
    eval_task = EvalTask(
        dataset=eval_dataset,
        metrics=eval_metrics,
        experiment=experiment_name,
    )

    # Evaluate the task
    result = eval_task.evaluate(experiment_run_name=f"{experiment_name}-{idx}")

    # Return the summary metrics
    return result.summary_metrics


def print_df_rows(
    df: pd.DataFrame, columns: list[str] | None = None, n: int = 3
) -> None:
    """Display the first rows of a dataframe, one labelled value at a time.

    Args:
        df: Dataframe to show.
        columns: Optional subset of columns to keep; all columns when falsy.
        n: Number of leading rows to display (default 3).
    """
    if columns:
        df = df[columns]

    # Shared CSS; the header variant only adds bold text.
    base_style = (
        "font-family: monospace; font-size: 14px; white-space: pre-wrap; width:"
        " auto; overflow-x: auto;"
    )
    header_style = base_style + "font-weight: bold;"

    for _, record in df.head(n).iterrows():
        for column in df.columns:
            # Column names are shown as titles, e.g. "tool_names" -> "Tool Names".
            heading = column.replace("_", " ").title()
            header_markup = f"<span style='{header_style}'>{heading}: </span>"
            value_markup = f"<span style='{base_style}'>{record[column]}</span><br>"
            display(HTML(header_markup))
            display(HTML(value_markup))
        # Horizontal rule between consecutive rows.
        display(HTML("<hr>"))


def plot_eval_metrics(
    eval_results: list[tuple[str, dict[str, float]]],
    metrics: list[str] | None = None,
) -> None:
    """Show a grouped bar chart of the "mean" summary metrics for each run.

    Args:
        eval_results: Pairs of (run title, summary metrics mapping). Each pair
            becomes one bar series named after its title.
        metrics: Optional list of substrings. When non-empty, only summary
            keys containing at least one of them are kept.
    """

    def _is_plotted(metric_key: str) -> bool:
        # A key must pass the optional substring filter and be a "mean" entry.
        if metrics and not any(wanted in metric_key for wanted in metrics):
            return False
        return "mean" in metric_key

    # One bar trace per evaluation run
    traces = []
    for run_title, run_summary in eval_results:
        plotted = {
            key: value for key, value in run_summary.items() if _is_plotted(key)
        }
        traces.append(
            go.Bar(
                x=list(plotted),
                y=list(plotted.values()),
                name=run_title,
            )
        )

    figure = go.Figure(data=traces)

    # Titles, axis labels and grouped bars in a single layout update
    figure.update_layout(
        title=go.layout.Title(text="Evaluation Metrics", x=0.5),
        xaxis_title="Metric Name",
        yaxis_title="Mean Value",
        barmode="group",
    )

    figure.show()


def create_target_column(row: dict[str, Any]) -> str:
    """Serialize a row's tool call into the JSON target string.

    Args:
        row: Mapping read for "tool_names" and, when that value is truthy,
            "tool_arguments".

    Returns:
        A JSON object string with an empty "content" field and a "tool_calls"
        list holding one entry when "tool_names" is truthy, otherwise empty.
    """
    tool_calls = []

    # Only emit a call when the row actually names a tool
    if row.get("tool_names"):
        tool_calls.append(
            {"name": row["tool_names"], "arguments": row["tool_arguments"]}
        )

    payload = {"content": "", "tool_calls": tool_calls}
    return json.dumps(payload)


def tool_config_to_dict(tool_config: ToolConfig | None) -> dict[str, Any] | None:
    """Build a plain-dict view of a ToolConfig's function calling settings.

    Args:
        tool_config: The ToolConfig to convert, or None.

    Returns:
        None when no config is given; otherwise a dict with a single
        "function_calling_config" entry holding the mode name and the list of
        allowed function names.
    """
    if tool_config is not None:
        # The settings are only reachable through the wrapped private message.
        # pylint: disable=protected-access
        calling_config = tool_config._gapic_tool_config.function_calling_config
        mode_name = calling_config.mode.name
        allowed_names = list(calling_config.allowed_function_names)
        return {
            "function_calling_config": {
                "mode": mode_name,
                "allowed_function_names": allowed_names,
            }
        }

    return None


def replace_type_key(data: dict[str, Any]) -> dict[str, Any]:
    """Rename "type_" keys to "type" inside the function declarations.

    Only the value stored under a top-level "function_declarations" key is
    rewritten, and only when that value is a list; every other top-level entry
    is carried over unchanged. The input mapping is not modified.

    Args:
        data: Tool dictionary to convert.

    Returns:
        A new dictionary with the renamed keys.
    """

    def _rename(node: Any) -> Any:
        # Lists are rebuilt element by element
        if isinstance(node, list):
            return [_rename(child) for child in node]
        # Anything that is neither a list nor a dict is returned as-is
        if not isinstance(node, dict):
            return node
        renamed = {}
        for name, child in node.items():
            renamed["type" if name == "type_" else name] = _rename(child)
        return renamed

    return {
        key: (
            [_rename(declaration) for declaration in value]
            if key == "function_declarations" and isinstance(value, list)
            else value
        )
        for key, value in data.items()
    }


def validate_tools(spec: str) -> None:
    """Check that a JSON tools specification has the expected structure.

    Args:
        spec: JSON text holding an object with a non-empty "tools" array, each
            item of which carries a non-empty "function_declarations" array.

    Raises:
        ValueError: If the parsed specification does not match the schema.
    """
    # The schema is assembled from the innermost piece outwards.
    parameters_schema = {
        "type": "object",
        "properties": {
            "type": {"type": "string"},
            "properties": {"type": "object"},
            "required": {
                "type": "array",
                "items": {"type": "string"},
            },
        },
        "required": ["type", "properties"],
    }

    declaration_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "description": {"type": "string"},
            "parameters": parameters_schema,
        },
        "required": ["name", "description", "parameters"],
    }

    tool_schema = {
        "type": "object",
        "properties": {
            "function_declarations": {
                "type": "array",
                # At least one declaration per tool
                "minItems": 1,
                "items": declaration_schema,
            }
        },
        "required": ["function_declarations"],
    }

    tools_schema = {
        "type": "object",
        "properties": {
            "tools": {
                "type": "array",
                # At least one tool in the specification
                "minItems": 1,
                "items": tool_schema,
            }
        },
        "required": ["tools"],
    }

    # Parsing happens outside the try block, so malformed JSON is not wrapped.
    parsed_spec = json.loads(spec)
    try:
        validate(instance=parsed_spec, schema=tools_schema)
    except ValidationError as err:
        raise ValueError(f"Invalid Tools specification: {err}") from err


def validate_tool_config(tool_config: str) -> None:
    """Check that a JSON tool_config string has the expected structure.

    Args:
        tool_config: JSON text holding an object with a
            "function_calling_config" entry whose "mode" is one of "AUTO",
            "ANY" or "NONE", optionally with a list of
            "allowed_function_names".

    Raises:
        ValueError: If the parsed tool_config does not match the schema.
    """
    function_calling_schema = {
        "type": "object",
        "properties": {
            "mode": {"type": "string", "enum": ["AUTO", "ANY", "NONE"]},
            "allowed_function_names": {
                "type": "array",
                "items": {"type": "string"},
            },
        },
        "required": ["mode"],
    }

    config_schema = {
        "type": "object",
        "properties": {"function_calling_config": function_calling_schema},
        "required": ["function_calling_config"],
    }

    # Only schema violations are translated; JSON parse errors propagate as-is.
    parsed_config = json.loads(tool_config)
    try:
        validate(instance=parsed_config, schema=config_schema)
    except ValidationError as err:
        raise ValueError(f"Invalid tool_config: {tool_config}") from err
