nl2sql_library/nl2sql/datasets/base.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implements core Dataset functionality.
A dataset is a group of databases, intended to represent a data warehouse
containing multiple databases.
"""
import re
import typing
import numpy as np
import pandas as pd
from langchain.prompts import PromptTemplate
from langchain.sql_database import SQLDatabase
from loguru import logger
from pydantic import (
BaseModel,
ConfigDict,
SkipValidation,
field_serializer,
field_validator,
)
from pydantic.networks import UrlConstraints
from pydantic_core import Url
from sqlalchemy import create_engine
from sqlalchemy.engine.base import Engine
from sqlalchemy.sql import expression as sqe
from sqlalchemy.sql.ddl import CreateTable
from sqlalchemy.sql.functions import func
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.sqltypes import VARCHAR
from typing_extensions import Self, TypedDict
from nl2sql.assets.prompts import ZeroShot as ZeroShotPrompts
ColName = typing.TypeVar("ColName", bound=str)
ColType = typing.TypeVar("ColType", bound=str)
ColDesc = typing.TypeVar("ColDesc", bound=str)
TabName = typing.TypeVar("TabName", bound=str)
TabDesc = typing.TypeVar("TabDesc", bound=str)
DBName = typing.TypeVar("DBName", bound=str)
DBDesc = typing.TypeVar("DBDesc", bound=str)
EnumValues = dict[DBName, dict[ColName, list[str]]]
ColDataDictionary = TypedDict(
"ColDataDictionary", {"description": ColDesc, "type": ColType}
)
BaseColDataDictionary = dict[ColName, ColDataDictionary]
TabDataDictionary = TypedDict(
"TabDataDictionary",
{"description": TabDesc, "columns": BaseColDataDictionary},
)
BaseTabDataDictionary = dict[TabName, TabDataDictionary]
DatabaseDataDictionary = TypedDict(
"DatabaseDataDictionary",
{"description": DBDesc, "tables": BaseTabDataDictionary},
)
DatasetDataDictionary = dict[DBName, DatabaseDataDictionary]
BaseTableSchema = dict[ColName, ColType]
BaseDatabaseSchema = dict[TabName, BaseTableSchema]
BaseDatasetSchema = dict[DBName, BaseDatabaseSchema]
DatabaseSchema = TypedDict(
"DatabaseSchema", {"metadata": MetaData, "tables": BaseDatabaseSchema}
)
DatasetSchema = dict[DBName, DatabaseSchema]
BaseColDescriptor = TypedDict(
"BaseColDescriptor",
{
"col_type": str,
"col_nullable": bool,
"col_pk": bool,
"col_defval": typing.Any | None,
"col_comment": str | None,
"col_enum_vals": list[str] | None,
"col_description": str | None,
},
)
BaseTabDescriptor = TypedDict(
"BaseTabDescriptor",
{
"table_name": TabName,
"table_creation_statement": str,
"table_sample_rows": str,
"col_descriptor": dict[str, BaseColDescriptor],
},
)
TableDescriptor = dict[TabName, BaseTabDescriptor]
AllowedDSN = typing.Annotated[
Url,
UrlConstraints(
host_required=False,
allowed_schemes=[
"postgres",
"postgresql",
"mysql",
"bigquery",
"sqlite"
],
),
]
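# Illustrative DSNs accepted by AllowedDSN (the hosts, paths and project
# names below are hypothetical placeholders):
#   "sqlite:///local/sales.db"
#   "postgresql://user:password@localhost:5432/sales"
#   "bigquery://my-project/my_dataset"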
class EntitySet(BaseModel):
"""
    Expects a list of identifiers of the form
    databasename.tablename.columnname, where each part may be the
    wildcard "*" (expanded against the dataset schema); "*.*.*"
    itself is not allowed.
"""
ids: list[str]
dataset_schema: BaseDatasetSchema
@field_validator("ids")
@classmethod
def id_structure(cls, ids: list[str]) -> list[str]:
"""
Validates incoming IDs
"""
for curr_id in ids:
assert curr_id != "*.*.*", '"*.*.*" is not allowed'
assert (
len(id_parts := curr_id.split(".")) == 3
), f"Malformed Entity ID {curr_id}"
for id_part, part_type in zip(
id_parts,
["database",
"table",
"column"]
):
assert id_part == "*" or re.match(
"^[a-zA-Z0-9_-]+$", id_part
), f"Malformed {part_type} '{id_part}' in {curr_id}"
return ids
def __hash__(self) -> int:
"""
Provides hash for the object
"""
return ",".join(sorted(self.ids)).__hash__()
def filter(
self, key: typing.Literal["database", "table", "column"], value: str
) -> "EntitySet":
"""
        Generates a new EntitySet by filtering the current IDs
        based on the key and value passed
"""
new_ids = []
for curr_id in self.ids:
db_name, tab_name, col_name = curr_id.split(".")
if {"database": db_name, "table": tab_name, "column": col_name}[
key
] == value:
new_ids.append(curr_id)
return EntitySet(
ids=new_ids,
dataset_schema=self.dataset_schema,
)
def invert(self) -> "EntitySet":
"""
        Returns the complement of the provided IDs based on the schema.
"""
return EntitySet(
ids=list(
{
f"{db}.{tab}.{col}"
for db, tabval in self.dataset_schema.items()
for tab, colval in tabval.items()
for col in colval.keys()
}
- set(self.ids)
),
dataset_schema=self.dataset_schema,
)
def prune_schema(self) -> BaseDatasetSchema:
"""
Reduces the schema to only contain the keys present in the provided IDs
"""
schema: BaseDatasetSchema = {}
for curr_id in self.ids:
dbname, tabname, colname = curr_id.split(".")
if dbname not in schema:
schema[dbname] = {}
if tabname not in schema[dbname]:
schema[dbname][tabname] = {}
if colname not in schema[dbname][tabname]:
                schema[dbname][tabname][colname] = self.dataset_schema[
                    dbname
                ][tabname][colname]
return schema
    def model_post_init(self, __context: object) -> None:
        """
        Resolves "*" wildcards in the IDs against the dataset schema and
        drops any IDs that do not exist in the schema.
        """
stack = list(set(self.ids))
resolved_ids: list[str] = []
while stack:
curr_id = stack.pop()
database, table, column = curr_id.split(".")
if database == "*":
stack.extend(
[f"{db}.{table}.{column}"
for db in self.dataset_schema.keys()]
)
elif table == "*":
stack.extend(
[
f"{database}.{tab}.{column}"
for tab in self.dataset_schema.get(database, {}).keys()
]
)
elif column == "*":
stack.extend(
[
f"{database}.{table}.{col}"
for col in self.dataset_schema.get(database, {})
.get(table, {})
.keys()
]
)
elif (
(database in self.dataset_schema.keys())
and (table in self.dataset_schema[database].keys())
and (column in self.dataset_schema[database][table].keys())
):
resolved_ids.append(curr_id)
else:
logger.debug(
f"Invalid Filter Expression Found: {curr_id}. Skipping."
)
self.ids = resolved_ids
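# A minimal EntitySet usage sketch (the "sales" schema below is a
# hypothetical example, not part of this module):
#
#   schema = {"sales": {"orders": {"id": "INTEGER", "total": "FLOAT"}}}
#   entities = EntitySet(ids=["sales.orders.*"], dataset_schema=schema)
#   sorted(entities.ids)     # -> ["sales.orders.id", "sales.orders.total"]
#   entities.invert().ids    # -> [] (nothing else exists in the schema)
#   entities.prune_schema()  # -> {"sales": {"orders": {"id": "INTEGER",
#                            #                          "total": "FLOAT"}}}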
class Database(BaseModel):
"""
Implements the core Database class which provides various
utilities for a DB
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
name: str
db: SQLDatabase
dsn: AllowedDSN
dbschema: BaseDatabaseSchema
enum_limit: int = 10
descriptor: TableDescriptor = {}
exclude_entities: EntitySet = EntitySet(ids=[], dataset_schema={})
data_dictionary: DatabaseDataDictionary | None = None
table_desc_template: SkipValidation[PromptTemplate] = (
ZeroShotPrompts.TABLE_DESCRIPTION_V3
)
# TODO @madhups - find and remove all uses of SkipValidation across all
# modules after Langchain has been ported to use Pydantic V2:
# https://github.com/langchain-ai/langchain/discussions/9337
@field_serializer("table_desc_template")
def serialize_prompt_template(self,
table_desc_template: PromptTemplate,
_info
):
"""
        Serializes the Langchain PromptTemplate for Pydantic model dumps
"""
return {
"template": table_desc_template.template,
"template_format": table_desc_template.template_format,
}
@field_serializer("db")
def serialize_db_basic(self, db: SQLDatabase, _info):
# pylint: disable=protected-access
"""
        Serializes key attributes of the Langchain SQLDatabase for
        Pydantic model dumps
"""
return {
"engine.url": db._engine.url.render_as_string(),
"_all_tables": db._all_tables,
"_usable_tables": db._usable_tables,
"_sample_rows_in_table_info": db._sample_rows_in_table_info,
"_indexes_in_table_info": db._indexes_in_table_info,
"_custom_table_info": db._custom_table_info,
"_max_string_length": db._max_string_length,
}
@classmethod
def fetch_schema(cls, name: str, dsn: AllowedDSN) -> BaseDatasetSchema:
"""
        Queries the database to determine the schema for a given name and DSN
"""
logger.info(f"[{name}] : Fetching Schema ...")
metadata = MetaData()
metadata.reflect(
bind=create_engine(dsn.unicode_string()),
views=True,
)
db_schema: BaseDatasetSchema = {name: {}}
if metadata.tables:
for tablename, table in metadata.tables.items():
tabledata = {}
for column in table.columns:
tabledata[column.name] = str(column.type)
if not tabledata:
msg = f"No columns found in {name}.{tablename}"
logger.critical(msg)
raise AttributeError(msg)
db_schema[name][tablename] = tabledata
if not db_schema[name]:
msg = f"No tables found in {name}"
logger.critical(msg)
raise AttributeError(msg)
logger.success(f"[{name}] : Schema Obtained Successfully")
return db_schema
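    # fetch_schema returns a plain nested dict, e.g. (hypothetical names):
    #   {"sales": {"orders": {"id": "INTEGER", "total": "FLOAT"}}}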
@classmethod
def from_connection_string(
cls,
name: str,
connection_string: str,
schema: BaseDatasetSchema | None = None,
**kwargs,
) -> Self:
"""
Utility function to create a database from a name and connection_string
"""
logger.debug(f"[{name}] : Analysing ...")
dsn = AllowedDSN(connection_string)
schema = schema or cls.fetch_schema(name=name, dsn=dsn)
engine = create_engine(connection_string)
assert isinstance(engine, Engine)
if ("exclude_entities" in kwargs) and (
not isinstance(kwargs["exclude_entities"], EntitySet)
):
kwargs["exclude_entities"] = [
EntitySet(ids=list(
kwargs["exclude_entities"]),
dataset_schema=schema
)
]
db = SQLDatabase(
engine=engine,
view_support=True,
)
logger.success(f"[{name}] : Analysis Complete")
return cls(
name=name,
dsn=dsn,
dbschema=schema[name],
db=db,
**kwargs,
)
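    # A minimal from_connection_string sketch (the SQLite path and the
    # excluded column below are hypothetical):
    #
    #   db = Database.from_connection_string(
    #       name="sales",
    #       connection_string="sqlite:///sales.db",
    #       exclude_entities=["sales.orders.internal_notes"],
    #   )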
def filter(
self,
filters: list[str], filter_type: typing.Literal["only", "exclude"]
) -> "Database":
"""
Returns a new database object after applying the provided filters
"""
entities = EntitySet(ids=filters,
dataset_schema={self.name: self.dbschema})
if filter_type == "only":
entities = entities.invert()
return Database(
name=self.name,
db=self.db,
dsn=self.dsn,
dbschema=self.dbschema,
enum_limit=self.enum_limit,
descriptor=self.descriptor,
exclude_entities=entities,
data_dictionary=self.data_dictionary,
table_desc_template=self.table_desc_template,
)
def execute(self, query: str) -> pd.DataFrame:
"""
Returns the results of a query as a Pandas DataFrame
"""
return pd.read_sql(sql=query, con=self.db._engine)
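    # Example (illustrative; the table name is hypothetical):
    #   df = db.execute("SELECT COUNT(*) AS n FROM orders")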
def model_post_init(self, __context: object) -> None:
# pylint: disable=protected-access, too-many-branches
"""
        Pydantic post-init hook that applies exclusions and builds the
        table descriptors for this DB
"""
logger.debug(f"[{self.name}] : Instantiating ...")
logger.debug(f"[{self.name}] : Calculating Exclusions ...")
table_exclusions = []
all_exclusions = set(self.exclude_entities.ids)
for tablename, tableinfo in self.dbschema.items():
if {
f"{self.name}.{tablename}.{column}"
for column in tableinfo.keys()
}.issubset(all_exclusions):
table_exclusions.append(tablename)
if table_exclusions:
logger.info(
f"[{self.name}] : These tables will be excluded :"
+ (", ".join(table_exclusions))
)
else:
logger.info(f"[{self.name}] : No tables will be excluded")
logger.success(f"[{self.name}] : Exclusions Calculated")
logger.debug(f"[{self.name}] : Generating Custom Descriptions ...")
engine = create_engine(self.dsn.unicode_string())
assert isinstance(engine, Engine)
temp_db = SQLDatabase(
engine=engine,
ignore_tables=table_exclusions,
view_support=True,
)
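        # The loop below builds one BaseTabDescriptor per non-excluded table:
        # its CREATE TABLE statement, sample rows and per-column metadata,
        # e.g. (hypothetical values):
        #   {"orders": {"table_name": "orders",
        #               "table_creation_statement": "CREATE TABLE orders (...)",
        #               "table_sample_rows": "...",
        #               "col_descriptor": {"id": {"col_type": "INTEGER", ...}}}}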
table_descriptor: dict[str, BaseTabDescriptor] = {}
table_descriptions = {}
for table in temp_db._metadata.sorted_tables:
if table.name in table_exclusions:
continue
constraints = {
col.name
for con in table.constraints
for col in con.columns # type: ignore
}
col_descriptor: dict[str, BaseColDescriptor] = {}
col_enums = []
            # Iterate over a copy: excluded columns are removed from the
            # table's column collection inside this loop.
            for col in list(table._columns):  # type: ignore
                if (col.name not in constraints) and (
                    f"{self.name}.{table.name}.{col.name}" in all_exclusions
                ):
                    logger.info(
                        f"[{self.name}.{table.name}] : "
                        f"Removing column {col.name}"
                    )
                    table._columns.remove(col)  # type: ignore
else:
if (table.name not in self.descriptor) or (
col.name
not in self.descriptor[table.name]["col_descriptor"]
):
if (
self.enum_limit > 0
) and (
col.type.python_type == str
):
col_enums.append(
sqe.select(
sqe.literal(col.name, VARCHAR).label(
"COLNAME"
),
sqe.case(
(
sqe.select(
func.count(sqe.distinct(col))
< self.enum_limit
).label("COLCOUNT"),
col,
)
).label("COLVALS"),
).distinct()
)
col_descriptor_map: BaseColDescriptor = {
"col_type": str(col.type),
"col_nullable": col.nullable,
"col_pk": col.primary_key,
"col_defval": col.default,
"col_comment": col.comment,
"col_enum_vals": None,
"col_description": (
(
self.data_dictionary["tables"][table.name][
"columns"
][col.name]["description"]
)
if (
(self.data_dictionary)
and (
table.name in self.data_dictionary[
"tables"
]
)
and (
col.name
in self.data_dictionary[
"tables"
][table.name][
"columns"
]
)
)
else None
),
}
else:
col_descriptor_map = self.descriptor[table.name][
"col_descriptor"
][col.name]
col_descriptor[col.name] = col_descriptor_map
for colname, colvals in (
(
pd.read_sql(sql=sqe.union(*col_enums), con=engine)
.replace("", np.nan)
.dropna()
.groupby("COLNAME", group_keys=False)["COLVALS"]
.apply(list)
.to_dict()
)
if col_enums
else {}
).items():
col_descriptor[colname]["col_enum_vals"] = colvals
table_descriptor[table.name] = {
"table_name": table.name,
"table_creation_statement": str(
CreateTable(table).compile(engine)
).rstrip(),
"table_sample_rows": temp_db._get_sample_rows(table),
"col_descriptor": col_descriptor,
}
logger.trace(
f"[{self.name}] : Table descriptor created for {table.name}"
+ f"\n{table_descriptor[table.name]}"
)
table_descriptions[table.name] = self.table_desc_template.format(
**{
key: value
for key, value in table_descriptor[table.name].items()
if key in self.table_desc_template.input_variables
}
)
self.descriptor = table_descriptor
logger.success(f"[{self.name}] : Custom Descriptions Generated")
temp_db._custom_table_info = table_descriptions
self.db = temp_db
logger.success(f"[{self.name}] : Instantiated")
DBNameDBMap = dict[DBName, Database]
class Dataset(BaseModel):
"""
A dataset is a collection of databases
"""
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
databases: DBNameDBMap
dataset_schema: BaseDatasetSchema
exclude_entities: EntitySet = EntitySet(ids=[], dataset_schema={})
data_dictionary: DatasetDataDictionary = {}
enum_limit: int = 10
table_desc_template: SkipValidation[PromptTemplate] = (
ZeroShotPrompts.TABLE_DESCRIPTION_V3
)
# TODO @madhups - remove SkipValidation after Pydantic V2 support is added
# to Langchain: https://github.com/langchain-ai/langchain/discussions/9337
def get_database(self, database_name: str) -> Database:
"""
Utility function to fetch a specific database from a dataset
"""
return self.databases[database_name]
def filter(
self,
filters: list[str],
filter_type: typing.Literal["only", "exclude"],
prune: bool = False,
) -> "Dataset":
"""
        Applies a filter to this dataset and returns a new instance.
"""
databases = {
k: v.filter(
filters,
filter_type
) for k, v in self.databases.items()
}
if prune:
            databases = {
                k: v for k, v in databases.items() if v.db.table_info
            }
return Dataset(
databases=databases,
dataset_schema=self.dataset_schema,
exclude_entities=self.exclude_entities,
data_dictionary=self.data_dictionary,
enum_limit=self.enum_limit,
table_desc_template=self.table_desc_template,
)
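    # A minimal filter sketch (the identifiers below are hypothetical):
    #
    #   only_orders = dataset.filter(
    #       filters=["sales.orders.*"], filter_type="only", prune=True
    #   )
    #   without_pii = dataset.filter(
    #       filters=["sales.customers.email"], filter_type="exclude"
    #   )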
@property
def list_databases(self) -> list[str]:
"""
        Returns a list of database names in this dataset
"""
return list(self.databases.keys())
@field_serializer("table_desc_template")
def serialize_prompt_template(self,
table_desc_template: PromptTemplate,
_info
):
"""
        Serializes the Langchain PromptTemplate for Pydantic model dumps
"""
return {
"template": table_desc_template.template,
"template_format": table_desc_template.template_format,
}
@classmethod
def from_connection_strings(
cls,
name_connstr_map: dict[str, str],
exclude_entities: list[str] = [],
**kwargs,
) -> Self:
"""
Utility function to create a dataset from a name -> conn_str mapping.
"""
dataset_schema = {
db_name: Database.fetch_schema(db_name,
AllowedDSN(db_connstr))[db_name]
for db_name, db_connstr in name_connstr_map.items()
}
parsed_exclude_entities = EntitySet(
ids=exclude_entities, dataset_schema=dataset_schema
)
data_dictionary = kwargs.pop("data_dictionary", dict())
databases = {
db_name: Database.from_connection_string(
name=db_name,
connection_string=db_connstr,
exclude_entities=parsed_exclude_entities.filter(
key="database", value=db_name
),
schema={db_name: dataset_schema[db_name]},
data_dictionary=data_dictionary.get(db_name),
**kwargs,
)
for db_name, db_connstr in name_connstr_map.items()
}
return cls(
databases=databases,
dataset_schema=dataset_schema,
exclude_entities=parsed_exclude_entities,
data_dictionary=data_dictionary,
**kwargs,
)
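# A minimal end-to-end sketch (the connection string and identifiers below
# are hypothetical):
#
#   dataset = Dataset.from_connection_strings(
#       name_connstr_map={"sales": "sqlite:///sales.db"},
#       exclude_entities=["sales.orders.internal_notes"],
#   )
#   db = dataset.get_database("sales")
#   print(db.db.table_info)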