nl2sql_library/nl2sql/assets/prompts/__init__.py (198 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
provides prompts as PromptTemplate objects
"""

import json
import pkgutil
from typing import List

from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from loguru import logger
from pydantic.v1 import BaseModel

from nl2sql.assets.examples import FewShot as FewShotExamples


class _ZeroShot:
    """
    Zero-shot prompts, exposed as read-only properties.

    Every property calls ``_create_template`` with the name of a JSON file
    that lives in this package, so each access re-reads the file and returns
    a newly built ``PromptTemplate``.
    """

    # pylint: disable=invalid-name, missing-function-docstring

    def _create_template(self, filename: str) -> PromptTemplate:
        """
        Builds a PromptTemplate from a JSON file packaged with this module.

        The JSON document must provide a ``template`` entry (an iterable of
        string fragments that are concatenated) and a ``template_format``
        entry, which is passed through to ``PromptTemplate.from_template``.

        Args:
            filename: Name of the JSON resource, relative to this package.

        Returns:
            The PromptTemplate described by the JSON resource.

        Raises:
            ValueError: If ``pkgutil.get_data`` returns no data for the file.
        """
        raw_data = pkgutil.get_data(__name__, filename)
        if raw_data is None:
            raise ValueError(f"{filename} cannot be read")
        template = json.loads(raw_data)
        return PromptTemplate.from_template(
            template="".join(template["template"]),
            template_format=template["template_format"],
        )

    @property
    def COLUMN_DESCRIPTION_V1(self) -> PromptTemplate:
        return self._create_template("column_description_v1.json")

    @property
    def COLUMN_DESCRIPTION_V2(self) -> PromptTemplate:
        return self._create_template("column_description_v2.json")

    @property
    def TABLE_DESCRIPTION_V1(self) -> PromptTemplate:
        return self._create_template("table_description_v1.json")

    @property
    def TABLE_DESCRIPTION_V2(self) -> PromptTemplate:
        return self._create_template("table_description_v2.json")

    @property
    def TABLE_DESCRIPTION_V3(self) -> PromptTemplate:
        return self._create_template("table_description_v3.json")

    @property
    def TABLE_FILTER_THOUGHT_GEN_V1(self) -> PromptTemplate:
        return self._create_template("table_filter_thought_gen_v1.json")

    @property
    def TABLE_FILTER_THOUGHT_SCORE_V1(self) -> PromptTemplate:
        return self._create_template("table_filter_thought_score_v1.json")

    @property
    def TABLE_FILTER_THOUGHT_SCORE_V2(self) -> PromptTemplate:
        return self._create_template("table_filter_thought_score_v2.json")

    @property
    def COLUMN_FILTER_THOUGHT_GEN_V1(self) -> PromptTemplate:
        return self._create_template("column_filter_thought_gen_v1.json")

    @property
    def COLUMN_FILTER_THOUGHT_SCORE_V1(self) -> PromptTemplate:
        return self._create_template("column_filter_thought_score_v1.json")

    @property
    def PROMPTING_STRAT_QUERY_RANK_V1(self) -> PromptTemplate:
        return self._create_template("prompting_strat_query_rank_v1.json")

    @property
    def PROMPTING_STRAT_SQL_GEN_V1(self) -> PromptTemplate:
        return self._create_template("prompting_strat_sql_gen_v1.json")

    @property
    def PROMPTING_STRAT_TABLE_FILTER_GEN_V1(self) -> PromptTemplate:
        return self._create_template("prompting_strat_table_filter_gen_v1.json")

    @property
    def PROMPTING_STRAT_COLUMN_FILTER_GEN_V1(self) -> PromptTemplate:
        return self._create_template("prompting_strat_column_filter_gen_v1.json")

    @property
    def TASK_TABLE_SELECTION_CORE_V1(self) -> PromptTemplate:
        return self._create_template("task_table_selection_core_v1.json")

    @property
    def TASK_COLUMN_SELECTION_CORE_V1(self) -> PromptTemplate:
        return self._create_template("task_column_selection_core_v1.json")

    @property
    def TASK_JOIN_SELECTION_CORE_V1(self) -> PromptTemplate:
        return self._create_template("task_join_selection_core_v1.json")

    @property
    def TASK_SQL_GENERATION_CORE_V1(self) -> PromptTemplate:
        return self._create_template("task_sql_generation_core_v1.json")

    @property
    def TASK_EVAL_FIX_CORE_V1(self) -> PromptTemplate:
        return self._create_template("task_eval_fix_core_v1.json")


ZeroShot = _ZeroShot()


class _FewShot:
    """
    Few-shot prompts, exposed as read-only properties.

    Every property logs that it is being instantiated and then wraps one of
    the ``ZeroShot`` templates together with a set of packaged examples into
    a ``FewShotPromptTemplate``. A new template is built on each access.
    """

    # pylint: disable=invalid-name, missing-function-docstring

    def _create_template_v1(
        self,
        example_prompt: PromptTemplate,
        input_variables: List[str],
        examples: List[dict],
        num_examples: int,
    ) -> FewShotPromptTemplate:
        """
        Builds a few-shot prompt from the leading ``num_examples`` examples.

        The example prompt is used both to render each example and, through
        its raw template text, as the suffix of the resulting prompt.

        Args:
            example_prompt: Template used for the examples and the suffix.
            input_variables: Input variables of the resulting prompt.
            examples: Candidate examples; only ``examples[:num_examples]``
                is used.
            num_examples: Number of leading examples to keep.

        Returns:
            The assembled FewShotPromptTemplate.
        """
        return FewShotPromptTemplate(
            example_prompt=example_prompt,
            input_variables=input_variables,
            suffix=example_prompt.template,
            examples=examples[:num_examples],
            template_format=example_prompt.template_format,
        )

    def _create_template_v2(
        self,
        example_prompt: PromptTemplate,
        examples: List[dict],
        num_examples: int,
    ) -> FewShotPromptTemplate:
        """
        Builds a few-shot prompt whose examples carry database descriptors.

        Each example names a dataset (``"dataset"``) and a filter
        (``"data_id"``). The dataset is filtered down with that filter and
        the non-empty descriptors of the remaining databases are attached to
        the example under ``"db_descriptor"``. Examples for which no
        descriptor is found are skipped. Collection stops as soon as
        ``num_examples`` examples have been gathered; asking for zero (or
        fewer) examples therefore yields a prompt without examples.

        Args:
            example_prompt: Template used for the examples and the suffix;
                its input variables become those of the resulting prompt.
            examples: Candidate examples, each providing at least the
                ``"dataset"`` and ``"data_id"`` keys.
            num_examples: Number of examples the prompt must contain.

        Returns:
            The assembled FewShotPromptTemplate.

        Raises:
            ValueError: If fewer than ``num_examples`` usable examples exist.
        """
        from nl2sql.datasets import (  # pylint: disable=import-outside-toplevel
            fetch_dataset,
        )

        # Importing this globally would make the entire prompt module dependent
        # on Datasets, while Datasets already depends on prompts, resulting in
        # a cyclic import error. Importing this here makes only _FewShot
        # dependent on Datasets, while Datasets does not depend on _FewShot.

        # Fetch every referenced dataset exactly once.
        dataset_map = {j: fetch_dataset(j) for j in {i["dataset"] for i in examples}}

        extended_examples: List[dict] = []
        for e in examples:
            # Check the quota before doing any work so that a request for
            # zero examples never collects one, and no example beyond the
            # quota is processed.
            if len(extended_examples) >= num_examples:
                break
            db_descriptor = {
                db.name: db.descriptor
                for db in (
                    dataset_map[e["dataset"]]
                    .filter(filters=e["data_id"], filter_type="only")
                    .databases.values()
                )
                if db.descriptor
            }
            if db_descriptor:
                extended_examples.append(
                    {
                        **e,
                        "db_descriptor": db_descriptor,
                    }
                )

        if len(extended_examples) < num_examples:
            raise ValueError(f"Unable to load {num_examples} Examples")

        return FewShotPromptTemplate(
            example_prompt=example_prompt,
            input_variables=example_prompt.input_variables,
            suffix=example_prompt.template,
            examples=extended_examples,
            template_format=example_prompt.template_format,
        )

    @property
    def PROMPTING_STRAT_FEW_SHOT_SQL_GEN_V1(self) -> FewShotPromptTemplate:
        logger.info("Instantiating PROMPTING_STRAT_FEW_SHOT_SQL_GEN_V1")
        return self._create_template_v1(
            ZeroShot.PROMPTING_STRAT_SQL_GEN_V1,
            input_variables=["context", "question", "query"],
            examples=FewShotExamples.EXAMPLES_SPIDER_SQL_QUERIES_V1,
            num_examples=5,
        )

    @property
    def PROMPTING_STRAT_FEW_SHOT_TABLE_FILTER_GEN_V1(self) -> FewShotPromptTemplate:
        logger.info("Instantiating PROMPTING_STRAT_FEW_SHOT_TABLE_FILTER_GEN_V1")
        return self._create_template_v1(
            ZeroShot.PROMPTING_STRAT_TABLE_FILTER_GEN_V1,
            input_variables=["context", "question", "thoughts"],
            examples=FewShotExamples.EXAMPLES_TABLE_FILTER_V1,
            num_examples=5,
        )

    @property
    def PROMPTING_STRAT_FEW_SHOT_COLUMN_FILTER_GEN_V1(self) -> FewShotPromptTemplate:
        logger.info("Instantiating PROMPTING_STRAT_FEW_SHOT_COLUMN_FILTER_GEN_V1")
        return self._create_template_v1(
            ZeroShot.PROMPTING_STRAT_COLUMN_FILTER_GEN_V1,
            input_variables=["context", "question", "thoughts", "answer"],
            examples=FewShotExamples.EXAMPLES_COLUMN_FILTER_V1,
            num_examples=5,
        )

    @property
    def SPIDER_FEW_SHOT_COLUMN_FILTER_V1(self) -> FewShotPromptTemplate:
        logger.info("Instantiating SPIDER_FEW_SHOT_COLUMN_FILTER_V1")
        return self._create_template_v2(
            ZeroShot.PROMPTING_STRAT_COLUMN_FILTER_GEN_V1,
            examples=FewShotExamples.EXAMPLES_COLUMN_FILTER_V2,
            num_examples=5,
        )

    @property
    def TASK_TABLE_SELECTION_CORE_V1_SPIDER_V1(self) -> FewShotPromptTemplate:
        logger.info("Instantiating TASK_TABLE_SELECTION_CORE_V1_SPIDER_V1")
        return self._create_template_v2(
            ZeroShot.TASK_TABLE_SELECTION_CORE_V1,
            examples=FewShotExamples.EXAMPLES_TABLE_FILTER_V2,
            num_examples=5,
        )

    @property
    def TASK_COLUMN_SELECTION_CORE_V1_SPIDER_V1(self) -> FewShotPromptTemplate:
        logger.info("Instantiating TASK_COLUMN_SELECTION_CORE_V1_SPIDER_V1")
        return self._create_template_v2(
            ZeroShot.TASK_COLUMN_SELECTION_CORE_V1,
            examples=FewShotExamples.EXAMPLES_COLUMN_FILTER_V2,
            num_examples=5,
        )

    @property
    def TASK_JOIN_SELECTION_CORE_V1_SPIDER_V1(self) -> FewShotPromptTemplate:
        logger.info("Instantiating TASK_JOIN_SELECTION_CORE_V1_SPIDER_V1")
        return self._create_template_v2(
            ZeroShot.TASK_JOIN_SELECTION_CORE_V1,
            examples=FewShotExamples.EXAMPLES_JOIN_IDENTIFICATION_V2,
            num_examples=5,
        )

    @property
    def TASK_SQL_GENERATION_CORE_V1_SPIDER_V1(self) -> FewShotPromptTemplate:
        logger.info("Instantiating TASK_SQL_GENERATION_CORE_V1_SPIDER_V1")
        return self._create_template_v2(
            ZeroShot.TASK_SQL_GENERATION_CORE_V1,
            examples=FewShotExamples.EXAMPLES_SQL_GENERATION_V1,
            num_examples=2,
        )


FewShot = _FewShot()