nl2sql_library/sample_executors/cot_executor.py (57 lines of code) (raw):
"""
Chain of Thought Executor Sample
"""
import json
# import os
import sys
from os.path import dirname, abspath
from loguru import logger
from nl2sql.llms.vertexai import text_bison_32k
from nl2sql.executors.linear_executor.core import CoreLinearExecutor
from nl2sql.tasks.table_selection.core import CoreTableSelector
from nl2sql.tasks.table_selection.core import prompts as cts_prompts
from nl2sql.tasks.column_selection.core import (
CoreColumnSelector,
prompts as ccs_prompts,
)
from nl2sql.tasks.sql_generation.core import CoreSqlGenerator
from nl2sql.tasks.sql_generation.core import prompts as csg_prompts
# Put the parent of this file's directory on the import path. It is inserted
# at index 1, i.e. right after the first existing sys.path entry.
sys.path.insert(1, dirname(dirname(abspath(__file__))))

dataset_name = "nl2sql_spider"  # @param {type:"string"}

# Load the cached Spider table metadata. The path is resolved against the
# current working directory, not against this file's location. The context
# manager closes the file handle as soon as the JSON has been parsed.
with open("../utils/spider_md_cache.json", encoding="utf-8") as f:
    spider_data = json.load(f)

# Data dictionary keyed by dataset name: a free-text dataset description plus
# the table metadata loaded above. The description is assembled with implicit
# string concatenation so that source indentation never leaks into the text.
data_dictionary_read = {
    "nl2sql_spider": {
        "description": (
            "This dataset contains information about the concerts, "
            "singers, country they belong to, stadiums where the "
            "concerts happened"
        ),
        "tables": spider_data,
    },
}
# A single LLM handle, obtained from text_bison_32k(), is shared by all three
# task components created below.
llm = text_bison_32k()

# Mute the "nl2sql.datasets.base" logger for the duration of the component
# setup, since these steps produce a large volume of log output.
# NOTE(review): the logger is re-enabled before the executor itself is built
# further down - confirm that this is the intended scope of the muting.
logger.disable("nl2sql.datasets.base")

# Table selection, using the curated few-shot chain-of-thought prompt.
core_table_selector = CoreTableSelector(
    llm=llm,
    prompt=cts_prompts.CURATED_FEW_SHOT_COT_PROMPT,
)

# Column selection, using the curated few-shot chain-of-thought prompt.
core_column_selector = CoreColumnSelector(
    llm=llm,
    prompt=ccs_prompts.CURATED_FEW_SHOT_COT_PROMPT,
)

# SQL generation, using the curated few-shot chain-of-thought prompt.
core_sql_generator = CoreSqlGenerator(
    llm=llm,
    prompt=csg_prompts.CURATED_FEW_SHOT_COT_PROMPT,
)

# Restore logging for "nl2sql.datasets.base".
logger.enable("nl2sql.datasets.base")
# Name of the dataset and the BigQuery connection string it maps to.
dataset_name = "nl2sql_spider"
bigquery_connection_string = "bigquery://sl-test-project-363109/nl2sql_spider"

# Build the linear executor from a {dataset name: connection string} map,
# wiring in the three task components and the data dictionary.
_connection_string_map = {dataset_name: bigquery_connection_string}
dd_cot_executor = CoreLinearExecutor.from_connection_string_map(
    _connection_string_map,
    core_table_selector=core_table_selector,
    core_column_selector=core_column_selector,
    core_sql_generator=core_sql_generator,
    data_dictionary=data_dictionary_read,
)

# Announce the new executor and show its identifier.
_banner = "=" * 25
print("\n\n", _banner, "Executor Created", _banner, "\n\n")
print("Executor ID :", dd_cot_executor.executor_id)
# Run the executor against the dataset with a sample natural-language question.
dd_cot_result = dd_cot_executor(
    db_name=dataset_name,
    question=(
        "What is the average, minimum, and maximum age "
        "of all singers from France ?"
    ),  # @param {type:"string"}
)

# Print the result identifier followed by the SQL that was generated.
_rule = "=" * 50
print("\n\n", _rule, "Generated SQL", _rule, "\n\n")
print("Result ID:", dd_cot_result.result_id, "\n\n")
print(dd_cot_result.generated_query)