# nl2sql_library/nl2sql/tasks/sql_generation/react.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implementation of the ReAct prompting based approach to SQL Generation
"""
from typing import Any
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.llms.base import BaseLLM
from loguru import logger
from pydantic import SkipValidation
from typing_extensions import Literal
from nl2sql.datasets.base import Database
from nl2sql.tasks.sql_generation import BaseSqlGenerationResult, BaseSqlGenerationTask
class ReactSqlGenratorResult(BaseSqlGenerationResult):
    """
    Result object produced by the ReAct SQL Generation task.

    Adds no behaviour beyond ``BaseSqlGenerationResult``; it only pins the
    ``resulttype`` discriminator to the ReAct generator's identifier.
    """

    # Constant tag naming the task that produced this result.
    resulttype: Literal[
        "Result.SqlGeneration.ReactSqlGenerator"
    ] = "Result.SqlGeneration.ReactSqlGenerator"
# Agent tools whose input carries the SQL statement we want to surface.
_SQL_TOOL_NAMES = ("sql_db_query", "sql_db_query_checker")

# Substrings removed from the raw tool input, applied in this exact order.
_SQL_NOISE_TOKENS = (";", "sql```", "```sql", "```")


def _strip_sql_decorations(raw_sql: str) -> str:
    """
    Removes semicolons and markdown code-fence markers from a SQL string.
    """
    for token in _SQL_NOISE_TOKENS:
        raw_sql = raw_sql.replace(token, "")
    return raw_sql


def _extract_last_query(steps: Any) -> str | None:
    """
    Returns the SQL passed to the most recent SQL tool call, or None.

    Steps are scanned from last to first. A step that cannot be indexed,
    that did not invoke one of the SQL tools, or whose input holds no SQL
    text is skipped instead of aborting the whole search, so a single odd
    step no longer hides an earlier valid query.
    """
    for step in reversed(steps or []):
        try:
            action = step[0]
        except (TypeError, IndexError, KeyError):
            # Not an (action, observation) pair - nothing to inspect.
            continue
        if getattr(action, "tool", None) not in _SQL_TOOL_NAMES:
            continue
        tool_input = getattr(action, "tool_input", None)
        if isinstance(tool_input, dict):
            # NOTE(review): structured agent types appear to pass the SQL
            # under a "query" key - confirm against the LangChain version
            # in use.
            tool_input = tool_input.get("query")
        if isinstance(tool_input, str):
            return _strip_sql_decorations(tool_input)
    return None


class ReactSqlGenerator(BaseSqlGenerationTask):
    """
    Implements ReAct SQL Generation Task

    Builds a LangChain SQL agent over the target database, lets it answer
    the question, and reports the last SQL statement the agent handed to one
    of its SQL tools as the generated query.
    """

    # Constant tag identifying this task implementation.
    tasktype: Literal[
        "Task.SqlGeneration.ReactSqlGenerator"
    ] = "Task.SqlGeneration.ReactSqlGenerator"

    # LLM used both by the agent and by the SQL toolkit.
    llm: SkipValidation[BaseLLM]
    # Agent flavour forwarded to ``create_sql_agent``.
    agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION
    # Iteration cap forwarded to ``create_sql_agent``.
    max_iterations: int | None = 15

    def __call__(self, db: Database, question: str) -> ReactSqlGenratorResult:
        """
        Runs the SQL Generation pipeline

        Args:
            db: Database wrapper; ``db.db`` is handed to the SQL toolkit and
                ``db.name`` is recorded on the result.
            question: Natural-language question given to the agent.

        Returns:
            A ``ReactSqlGenratorResult`` whose ``generated_query`` is the
            cleaned SQL from the agent's last SQL tool call, or ``None``
            when the agent failed or never called a SQL tool. The agent's
            steps, plus a message for any failure caught while running the
            agent or extracting the query, are in ``intermediate_steps``.
        """
        logger.info(f"Running {self.tasktype} ...")
        agent = create_sql_agent(
            llm=self.llm,
            toolkit=SQLDatabaseToolkit(db=db.db, llm=self.llm),
            agent_type=self.agent_type,
            # max_rows_limit is not defined in this class; presumably it
            # comes from BaseSqlGenerationTask.
            top_k=self.max_rows_limit,
            max_iterations=self.max_iterations,
            verbose=False,
            early_stopping_method="generate",
        )
        agent.return_intermediate_steps = True
        agent.handle_parsing_errors = True

        intermediate_steps: list[Any] = []
        query: str | None = None
        try:
            result = agent(question)
        except Exception as exc:  # pylint: disable=broad-exception-caught
            # Best effort: record the failure and return a result with no
            # query rather than raising to the caller.
            intermediate_steps.append(f"Exception in Agent.run : {exc}")
        else:
            raw_steps = result.get("intermediate_steps") or []
            intermediate_steps.extend(
                [
                    {"input": step[0].to_json(), "output": step[1]}
                    if isinstance(step, tuple)
                    else step
                    for step in raw_steps
                ]
            )
            try:
                query = _extract_last_query(raw_steps)
            except Exception as exc:  # pylint: disable=broad-exception-caught
                query = None
                intermediate_steps.append(f"Exception in parsing query : {exc}")

        return ReactSqlGenratorResult(
            db_name=db.name,
            question=question,
            generated_query=query,
            intermediate_steps=intermediate_steps,
        )