nl2sql/tasks/join_selection/core.py (170 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implementation of the core prompting-based approach to Join Selection.
"""
from typing import Callable
from uuid import uuid4
from langchain.llms.base import BaseLLM
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BasePromptTemplate
from loguru import logger
from pydantic import BaseModel, SkipValidation
from typing_extensions import Literal
from nl2sql.assets.prompts import FewShot as FewShotPrompts
from nl2sql.assets.prompts import ZeroShot as ZeroShotPrompts
from nl2sql.datasets.base import Database
from nl2sql.tasks.join_selection import BaseJoinSelectionResult, BaseJoinSelectionTask
class _CoreJoinSelectorPrompt(BaseModel):
    """
    A Wrapper around Join Selector Prompts

    Bundles a prompt template with the optional parser and the
    post-processing callable that CoreJoinSelector applies, in that order,
    to the raw LLM response.
    """

    # Identifier of the prompt; within this module it is either a curated
    # prompt name or "CUSTOM-<id>".
    prompt_id: str
    # Template that CoreJoinSelector formats with the question / database
    # details. Pydantic validation is skipped for this field.
    prompt_template: SkipValidation[BasePromptTemplate]
    # Optional parser for the raw LLM text; when None, the raw text is handed
    # to post_processor unchanged.
    parser: SkipValidation[StructuredOutputParser] | None = None
    # Callable applied to the parsed (or raw) response. CoreJoinSelector
    # iterates over its return value to collect the selected joins.
    post_processor: Callable
class _JoinSelectorPrompts:
    # pylint: disable=missing-function-docstring, invalid-name
    """
    Provides prompt options for selecting Joins before generating SQL

    The curated prompts are exposed as properties, so every access builds a
    fresh _CoreJoinSelectorPrompt. Both curated prompts share default_parser
    and the same join normalisation, so a response is handled identically
    whether the LLM returns the "joins" field as a comma separated string or
    as a JSON list.
    """

    default_parser = StructuredOutputParser.from_response_schemas(
        [
            ResponseSchema(
                name="thoughts",
                description=(
                    "A short analysis of the question and available tables and "
                    "columns, demonstrating which joins would help in answering"
                    " the question and why. If no joins are needed, explain "
                    "why."
                ),
            ),
            ResponseSchema(
                name="joins",
                description=(
                    "A comma separated list of joining conditions that would "
                    "help in answering the question in the format "
                    "tablename1.columnname1=tablename2.columnname2. If no "
                    "joins are needed, this field should be null."
                ),
            ),
        ]
    )

    @staticmethod
    def _normalize_joins(parsed_output) -> list[str]:
        """
        Extracts a clean list of join conditions from a parsed LLM response.

        Args:
            parsed_output: The mapping produced by the parser (expected to
                carry a "joins" entry), or a falsy value.

        Returns:
            The join conditions with all spaces removed. An empty list is
            returned when the response is empty, when "joins" is missing or
            null, or when "joins" is neither a string nor a list / tuple.
        """
        if not parsed_output:
            return []
        raw_joins = parsed_output.get("joins")
        if isinstance(raw_joins, str):
            # The response schema asks for a comma separated string.
            candidates = raw_joins.split(",")
        elif isinstance(raw_joins, (list, tuple)):
            # LLMs frequently answer with a JSON list instead; ignore any
            # non-string members (e.g. a null inside the list).
            candidates = [item for item in raw_joins if isinstance(item, str)]
        else:
            return []
        cleaned = (candidate.strip().replace(" ", "") for candidate in candidates)
        # Drop blanks left by trailing commas and textual "null" / "none"
        # placeholders, neither of which can be a join condition.
        return [
            join
            for join in cleaned
            if join and join.lower() not in ("null", "none")
        ]

    @property
    def CURATED_ZERO_SHOT_PROMPT(self) -> _CoreJoinSelectorPrompt:
        """
        Builds the curated zero-shot join selection prompt.
        """
        prompt_template = ZeroShotPrompts.TASK_JOIN_SELECTION_CORE_V1.partial(
            format_instructions=self.default_parser.get_format_instructions()
        )
        return _CoreJoinSelectorPrompt(
            prompt_id="TASK_JOIN_SELECTION_CORE_V1",
            prompt_template=prompt_template,
            parser=self.default_parser,
            post_processor=self._normalize_joins,
        )

    @property
    def CURATED_FEW_SHOT_COT_PROMPT(self) -> _CoreJoinSelectorPrompt:
        """
        Builds the curated few-shot chain-of-thought join selection prompt.
        """
        format_instructions = self.default_parser.get_format_instructions()
        prompt_template = FewShotPrompts.TASK_JOIN_SELECTION_CORE_V1_SPIDER_V1.partial(
            format_instructions=format_instructions
        )
        # The per-example template is also given the format instructions.
        prompt_template.example_prompt = prompt_template.example_prompt.partial(  # type: ignore
            format_instructions=format_instructions
        )
        return _CoreJoinSelectorPrompt(
            prompt_id="TASK_JOIN_SELECTION_CORE_V1_SPIDER_V1",
            prompt_template=prompt_template,
            parser=self.default_parser,
            post_processor=self._normalize_joins,
        )

    @classmethod
    def custom_prompt(
        cls,
        prompt_template: BasePromptTemplate,
        parser: StructuredOutputParser | None = None,
        post_processor: Callable = lambda x: x,
        prompt_template_id: str | None = None,
    ) -> _CoreJoinSelectorPrompt:
        """
        Use a custom PromptTemplate for Join filtering.

        Args:
            prompt_template: The template to use.
            parser: Optional output parser. When given, its format
                instructions are bound to the template (and to the
                template's example_prompt, if it has one that is a
                PromptTemplate) as "format_instructions".
            post_processor: Callable applied to the parsed (or raw, when no
                parser is given) LLM response. Defaults to the identity
                function; CoreJoinSelector iterates over the value it
                returns, so it should yield join condition strings.
            prompt_template_id: Optional identifier; a random hex id is
                generated when omitted or empty.

        Returns:
            A prompt wrapper whose prompt_id is "CUSTOM-<prompt_template_id>".
        """
        if not prompt_template_id:
            prompt_template_id = uuid4().hex
        if parser:
            format_instructions = parser.get_format_instructions()
            prompt_template = prompt_template.partial(
                format_instructions=format_instructions
            )
            example_prompt = getattr(prompt_template, "example_prompt", None)
            if isinstance(example_prompt, PromptTemplate):
                prompt_template.example_prompt = example_prompt.partial(  # type: ignore
                    format_instructions=format_instructions
                )
        return _CoreJoinSelectorPrompt(
            prompt_id=f"CUSTOM-{prompt_template_id}",
            prompt_template=prompt_template,
            post_processor=post_processor,
            parser=parser,
        )
# Module-level instance through which the curated prompts (exposed as
# properties) and the custom_prompt factory are accessed.
prompts = _JoinSelectorPrompts()
class CoreJoinSelectorResult(BaseJoinSelectionResult):
    """
    Result type returned by the Core Join Selector task
    """

    # Fixed tag identifying this result type.
    resulttype: Literal["Result.JoinSelection.CoreJoinSelector"] = (
        "Result.JoinSelection.CoreJoinSelector"
    )
class CoreJoinSelector(BaseJoinSelectionTask):
    """
    Implements Core Join Selector Task

    Prompts an LLM with the question and database details, parses and
    post-processes the answer into join conditions, and maps them onto the
    foreign-key joins found in the database metadata.
    """

    # Fixed tag identifying this task type.
    tasktype: Literal[
        "Task.JoinSelection.CoreJoinSelector"
    ] = "Task.JoinSelection.CoreJoinSelector"

    llm: SkipValidation[BaseLLM]
    # The default prompt is built once, when this class body is evaluated.
    prompt: SkipValidation[_CoreJoinSelectorPrompt] = prompts.CURATED_ZERO_SHOT_PROMPT

    def __call__(self, db: Database, question: str) -> CoreJoinSelectorResult:
        """
        Runs the Join Selection pipeline

        Args:
            db: The database whose tables and foreign keys are considered.
            question: The natural language question to answer.

        Returns:
            A CoreJoinSelectorResult holding the joins derived from the
            database's foreign keys (allowed_joins), the joins picked by the
            LLM (selected_joins) and the intermediate steps of this run.

        Raises:
            ValueError: If the LLM response contains no generations.
        """
        logger.info(f"Running {self.tasktype} ...")
        # Superset of the variables a template may declare; only the ones
        # listed in the active template's input_variables are passed on.
        prompt_params = {
            "question": question,
            "query": question,
            "thoughts": [],
            "answer": None,
            "db_descriptor": {db.name: db.descriptor},
            "table_name": ", ".join(db.db._usable_tables),
            "table_names": list(db.db._usable_tables),
        }
        prepared_prompt = self.prompt.prompt_template.format(
            **{
                k: v
                for k, v in prompt_params.items()
                if k in self.prompt.prompt_template.input_variables
            }
        )
        llm_response = self.llm.generate([prepared_prompt])
        logger.debug(
            f"[{self.tasktype}] : Received LLM Response : {llm_response.json()}"
        )
        try:
            raw_response = llm_response.generations[0][0].text.strip()
        except IndexError as exc:
            raise ValueError(
                f"Empty / Invalid Response received from LLM : {llm_response.json()}"
            ) from exc

        parsed_response = (
            self.prompt.parser.parse(raw_response)
            if self.prompt.parser
            else raw_response
        )
        processed_response = self.prompt.post_processor(parsed_response)
        intermediate_steps = [
            {
                "tasktype": self.tasktype,
                "prepared_prompt": prepared_prompt,
                "llm_response": llm_response.dict(),
                "raw_response": raw_response,
                "parsed_response": parsed_response,
                "processed_response": processed_response,
            }
        ]

        # One "table.column=<target>" entry per foreign key in the metadata.
        # NOTE(review): _colspec is presumably SQLAlchemy's "table.column"
        # target spec of the foreign key - confirm.
        allowed_joins = {
            f"{tname}.{fk.parent.name}={getattr(fk, '_colspec')}"
            for tname, tobj in db.db._metadata.tables.items()
            for fk in tobj.foreign_keys
            if fk.parent is not None
        }
        allowed_joins_lower_map = {i.lower(): i for i in allowed_joins}
        # The map is keyed on lowercase, so the lookup must lowercase the
        # LLM's answer too; otherwise the canonical spelling is only found
        # when the LLM happened to reply in lowercase. Joins that match no
        # foreign key are kept as returned. Non-string items (possible with
        # a custom post_processor) are looked up unchanged.
        selected_joins = {
            allowed_joins_lower_map.get(
                join.lower() if isinstance(join, str) else join, join
            )
            for join in processed_response
        }
        if not selected_joins:
            logger.critical("No Join Selected!")
        return CoreJoinSelectorResult(
            db_name=db.name,
            question=question,
            allowed_joins=allowed_joins,
            selected_joins=selected_joins,
            intermediate_steps=intermediate_steps,
        )