# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Definition of all the Gemini function calling declarations which are used
by the nl2sql agent.
"""

import textwrap
from vertexai.generative_models import FunctionDeclaration


# Declaration of the `list_tables` function. Its parameter schema is an
# object with no properties, i.e. the function takes no arguments.
list_tables_func = FunctionDeclaration(
    name="list_tables",
    # Adjacent string literals are joined by the parser, so the description
    # is single-spaced. (A backslash continuation inside one literal would
    # embed the next line's source indentation in the text.)
    description=(
        "List tables in a dataset that will help answer the "
        "user's question by choosing the table first. This will ensure, "
        "we do not create query on wrong table."
    ),
    parameters={
        "type": "object",
    },
)

# Declaration of the `get_table_metadata` function. It takes a single
# required string argument, `table_id`.
get_table_metadata_func = FunctionDeclaration(
    name="get_table_metadata",
    # Adjacent string literals are joined by the parser, so the description
    # is single-spaced. (A backslash continuation inside one literal would
    # embed the next line's source indentation in the text.)
    description=(
        "Get information about a table, including the description, "
        "schema, and number of rows that will help answer the user's "
        "question. Always use the fully qualified dataset and "
        "table names."
    ),
    parameters={
        "type": "object",
        "properties": {
            "table_id": {
                "type": "string",
                "description": (
                    "Fully qualified ID of the table to get "
                    "information about"
                ),
            },
        },
        "required": [
            "table_id",
        ],
    },
)

# Declaration of the `sql_query` function. It takes a single required string
# argument, `query`, holding the SQL text.
sql_query_func = FunctionDeclaration(
    name="sql_query",
    # Built from adjacent literals so the text carries no source indentation
    # or stray leading/trailing newlines (as a triple-quoted block would).
    description=(
        "Get information from data in BigQuery using SQL queries. "
        "Also generates SQL by observing the error output from previous "
        "SQL execution if SQL query fails to execute"
    ),
    parameters={
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": (
                    "SQL query on a single line that will help give "
                    "quantitative answers to the user's question when run "
                    "on a BigQuery dataset and table. In the SQL query, "
                    "always use the fully qualified dataset and table "
                    "names. Do not use hardcoded dates in query, always "
                    "use date or time functions instead. "
                    "If SQL query failed to execute in past, then observe "
                    "the past error and correct that and then create the "
                    "SQL."
                ),
            },
        },
        "required": [
            "query",
        ],
    },
)

# Declaration of the `debug_sql_query` function. It takes a single required
# string argument, `query`, holding the corrected SQL text.
self_debug_func = FunctionDeclaration(
    name="debug_sql_query",
    # Adjacent string literals are joined by the parser, so the description
    # is single-spaced. (A backslash continuation inside one literal would
    # embed the next line's source indentation in the text.)
    description=(
        "Generates SQL by observing the error output from previous "
        "SQL execution if SQL query fails to execute"
    ),
    parameters={
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": (
                    "If SQL query failed to execute in past, "
                    "then observe the past error and correct that and then "
                    "create the SQL. In the SQL query, always use the fully "
                    "qualified dataset and table names. Do not use "
                    "hardcoded dates in query, always use date or time "
                    "functions instead."
                ),
            },
        },
        "required": [
            "query",
        ],
    },
)

# Declaration of the `plot_chart_auto` function. It takes a single required
# string argument, `code`, which is described to the model as Python/plotly
# source that leaves the figure in a variable named `fig`.
plot_chart_auto_func = FunctionDeclaration(
    name="plot_chart_auto",
    # Adjacent string literals are joined by the parser, so the description
    # is single-spaced. (A backslash continuation inside one literal would
    # embed the next line's source indentation in the text.)
    description=(
        "extract the data from the SQL query output and writes "
        "python code to plot charts to best visualize that data."
    ),
    parameters={
        "type": "object",
        "properties": {
            "code": {
                "type": "string",
                # The prompt text is kept at column 0 so the embedded example
                # keeps its own relative indentation; dedent() then only
                # strips the whitespace-only line before the closing quotes.
                "description": textwrap.dedent('''
First extract the necessary data, best suitable chart-type for that data,
title, axis and any other required information to plot the chart successfully.
Get the data in such a way that it does not take too many repetitions or
large number of characters. Then write python code to create a plot using this
information in plotly module.
Finally the code should store the plot in fig variable and do NOT do
fig.show().
Example:
```
import plotly.express as px
import pandas as pd

# Data extracted from the SQL output
data = {'Year': [2018, 2019, 2020, 2021, 2022],
        'Sales': [15000, 18000, 16500, 22000, 25000]}

df = pd.DataFrame(data)
# Create the Plot
fig = px.line(df, x='Year', y='Sales', title='Annual Sales Trend')
# fig.show() # DO NOT show the figure yet, as it will be shown by the UI code
```
                '''),
            },
        },
        "required": [
            "code",
        ],
    },
)


# Declaration of the `plot_chart` function. `data` and `plot_type` are
# required; `title`, `x_axis` and `y_axis` are declared but optional.
plot_chart_func = FunctionDeclaration(
    name="plot_chart",
    description=(
        "get the data to plot chart in python plotly."
        " Also determines the best suit plot-type for that data."
    ),
    parameters={
        "type": "object",
        "properties": {
            "data": {
                "type": "string",
                # NOTE(review): the text asks for "standard json" but the
                # example uses single quotes, which is not strict JSON.
                # Confirm how the consumer of `data` parses it before
                # changing the example.
                "description": textwrap.dedent('''
                    The data output from SQL query in a standard json
                    format which must be directly convertible to
                    a pandas dataframe. It must not have any extra keys
                    like "content".
                    example data: ```"{'month': ['Jan', 'Feb', 'Mar'],
                    'count': [62, 64, 20]
                    }"```
                '''),
            },
            "plot_type": {
                "type": "string",
                "enum": ["bar", "line", "pie", "scatter"],
                # Adjacent literals keep the text single-spaced; backslash
                # continuations would embed the source indentation.
                "description": (
                    "The type of plot to be generated best fit "
                    "for understanding the data. e.g. "
                    "bar, line, pie, scatter"
                ),
            },
            "title": {
                "type": "string",
                "description": "The title of the plot.",
            },
            "x_axis": {
                "type": "string",
                "description": (
                    "The column to be used for the x-axis "
                    "of the plot."
                ),
            },
            "y_axis": {
                "type": "string",
                "description": (
                    "The column to be used for the y-axis "
                    "of the plot."
                ),
            },
            # NOTE: further styling properties (xlabel, ylabel, legend,
            # color, size, alpha, marker, linestyle, linewidth, grid,
            # xticks, yticks, xlim, ylim) were previously sketched here as
            # commented-out code. Add them as real string properties if
            # they are ever needed.
        },
        "required": [
            "data",
            "plot_type",
        ],
    },
)
