in nl2sql_library/nl2sql/tasks/table_selection/core.py [0:0]
def __call__(self, db: Database, question: str) -> CoreTableSelectorResult:
    """
    Run the table-selection pipeline for ``question`` against ``db``.

    When ``self.prompt.call_for_each_table`` is truthy the LLM is queried
    once per entry of ``db.descriptor``; otherwise it is queried a single
    time with the whole descriptor. Each LLM response is passed through
    the prompt's optional parser and its post-processor, and every call is
    recorded as an intermediate step. The resulting selection is
    intersected with ``db.db._usable_tables``; if that leaves nothing and
    the prompt defines a greedy post-processor, the post-processor is run
    on the recorded steps to produce the selection instead.

    Args:
        db: Database whose descriptor and usable tables are consulted.
        question: Natural-language question used to fill the prompt.

    Returns:
        A ``CoreTableSelectorResult`` carrying the database name, the
        question, the usable tables, the selected tables and the
        recorded intermediate steps.

    Raises:
        ValueError: If the LLM response has no generation at ``[0][0]``.
    """
    logger.info(f"Running {self.tasktype} ...")
    chosen_tables = []
    steps = []

    # Build the mapping "label -> descriptor passed to the prompt".
    if not self.prompt.call_for_each_table:
        # One call covering every usable table at once.
        llm_targets = {",".join(db.db._usable_tables): {db.name: db.descriptor}}
    else:
        # One call per table, each seeing only that table's descriptor.
        llm_targets = {}
        for name, descriptor in db.descriptor.items():
            llm_targets[name] = {db.name: {name: descriptor}}

    for target_name, target_descriptor in llm_targets.items():
        candidate_params = {
            "question": question,
            "query": question,
            "thoughts": [],
            "answer": None,
            "db_descriptor": target_descriptor,
            "table_name": target_name,
            "table_names": list(db.db._usable_tables),
        }
        # Only hand the template the variables it actually declares.
        template = self.prompt.prompt_template
        template_inputs = {}
        for key, value in candidate_params.items():
            if key in template.input_variables:
                template_inputs[key] = value
        prompt_text = template.format(**template_inputs)

        response = self.llm.generate([prompt_text])
        logger.debug(
            f"[{self.tasktype}] : Received LLM Response : {response.json()}"
        )
        try:
            raw_text = response.generations[0][0].text.strip()
        except IndexError as exc:
            raise ValueError(
                f"Empty / Invalid Response received from LLM : {response.json()}"
            ) from exc

        # The parser is optional; without one the raw text is used as-is.
        if self.prompt.parser:
            parsed = self.prompt.parser.parse(raw_text)
        else:
            parsed = raw_text
        processed = self.prompt.post_processor(parsed)

        steps.append(
            {
                "tasktype": self.tasktype,
                "table": target_name,
                "prepared_prompt": prompt_text,
                "llm_response": response.dict(),
                "raw_response": raw_text,
                "parsed_response": parsed,
                "processed_response": processed,
            }
        )

        if not processed:
            continue
        if self.prompt.call_for_each_table:
            # A truthy result marks this particular table as selected.
            chosen_tables.append(target_name)
        else:
            # The processed response itself is the selection.
            chosen_tables = processed

    # Keep only names that are also present in the usable tables.
    final_tables: set[str] = set(chosen_tables).intersection(
        db.db._usable_tables
    )

    if not final_tables and self.prompt.greedy_post_processor:
        logger.debug(f"[{self.tasktype}] : Running Greedy Post Processor...")
        final_tables = self.prompt.greedy_post_processor(
            steps, db.db._usable_tables
        )
        logger.trace(
            f"[{self.tasktype}] : Tables selected after greedy filtering : {final_tables}"
        )

    if not final_tables:
        logger.critical("No table Selected!")

    return CoreTableSelectorResult(
        db_name=db.name,
        question=question,
        available_tables=db.db._usable_tables,
        selected_tables=final_tables,
        intermediate_steps=steps,
    )