nl2sql_library/nl2sql/datasets/base.py (535 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Implements core Dataset functionality. A dataset is a group of databases,
intended to represent a data warehouse containing multiple databases.
"""

import re
import typing

import numpy as np
import pandas as pd
from langchain.prompts import PromptTemplate
from langchain.sql_database import SQLDatabase
from loguru import logger
from pydantic import (
    BaseModel,
    ConfigDict,
    SkipValidation,
    field_serializer,
    field_validator,
)
from pydantic.networks import UrlConstraints
from pydantic_core import Url
from sqlalchemy import create_engine
from sqlalchemy.engine.base import Engine
from sqlalchemy.sql import expression as sqe
from sqlalchemy.sql.ddl import CreateTable
from sqlalchemy.sql.functions import func
from sqlalchemy.sql.schema import Column, MetaData, Table
from sqlalchemy.sql.sqltypes import VARCHAR
from typing_extensions import Self, TypedDict

from nl2sql.assets.prompts import ZeroShot as ZeroShotPrompts

ColName = typing.TypeVar("ColName", bound=str)
ColType = typing.TypeVar("ColType", bound=str)
ColDesc = typing.TypeVar("ColDesc", bound=str)
TabName = typing.TypeVar("TabName", bound=str)
TabDesc = typing.TypeVar("TabDesc", bound=str)
DBName = typing.TypeVar("DBName", bound=str)
DBDesc = typing.TypeVar("DBDesc", bound=str)

EnumValues = dict[DBName, dict[ColName, list[str]]]

# --- Data dictionaries: user-supplied descriptions of dbs / tables / columns
ColDataDictionary = TypedDict(
    "ColDataDictionary", {"description": ColDesc, "type": ColType}
)
BaseColDataDictionary = dict[ColName, ColDataDictionary]
TabDataDictionary = TypedDict(
    "TabDataDictionary",
    {"description": TabDesc, "columns": BaseColDataDictionary},
)
BaseTabDataDictionary = dict[TabName, TabDataDictionary]
DatabaseDataDictionary = TypedDict(
    "DatabaseDataDictionary",
    {"description": DBDesc, "tables": BaseTabDataDictionary},
)
DatasetDataDictionary = dict[DBName, DatabaseDataDictionary]

# --- Schemas: {db: {table: {column: column-type-as-string}}}
BaseTableSchema = dict[ColName, ColType]
BaseDatabaseSchema = dict[TabName, BaseTableSchema]
BaseDatasetSchema = dict[DBName, BaseDatabaseSchema]
DatabaseSchema = TypedDict(
    "DatabaseSchema", {"metadata": MetaData, "tables": BaseDatabaseSchema}
)
DatasetSchema = dict[DBName, DatabaseSchema]

# --- Descriptors generated in Database.model_post_init
BaseColDescriptor = TypedDict(
    "BaseColDescriptor",
    {
        "col_type": str,
        "col_nullable": bool,
        "col_pk": bool,
        "col_defval": typing.Any | None,
        "col_comment": str | None,
        "col_enum_vals": list[str] | None,
        "col_description": str | None,
    },
)
BaseTabDescriptor = TypedDict(
    "BaseTabDescriptor",
    {
        "table_name": TabName,
        "table_creation_statement": str,
        "table_sample_rows": str,
        "col_descriptor": dict[str, BaseColDescriptor],
    },
)
TableDescriptor = dict[TabName, BaseTabDescriptor]

# Connection strings accepted by this module; the host part is optional so
# that e.g. sqlite file URLs are allowed.
AllowedDSN = typing.Annotated[
    Url,
    UrlConstraints(
        host_required=False,
        allowed_schemes=[
            "postgres", "postgresql", "mysql", "bigquery", "sqlite"
        ],
    ),
]

# A single part of an entity id: either "*" or this identifier pattern.
_ID_PART_PATTERN = re.compile(r"[a-zA-Z0-9_-]+")


def _is_string_column(column: Column) -> bool:
    """
    Reports whether a SQLAlchemy column maps to the Python ``str`` type.

    ``TypeEngine.python_type`` raises ``NotImplementedError`` for types with
    no Python equivalent (e.g. ``NullType``, used for unrecognised reflected
    types). Such columns are treated as non-string instead of aborting the
    whole database initialisation.
    """
    try:
        return column.type.python_type is str
    except NotImplementedError:
        return False


class EntitySet(BaseModel):
    """
    Expects a list of identifiers in the form of
    databasename.tablename.columnname

    Any part may be the wildcard "*" (but not all three at once). After
    construction, ``ids`` holds only fully-qualified, de-duplicated ids that
    exist in ``dataset_schema``; wildcards are expanded and unknown ids are
    dropped (with a debug log).
    """

    ids: list[str]
    dataset_schema: BaseDatasetSchema

    @field_validator("ids")
    @classmethod
    def id_structure(cls, ids: list[str]) -> list[str]:
        """
        Validates incoming IDs.

        Raises:
            ValueError: (surfaced by pydantic as ``ValidationError``) when an
                id is "*.*.*", does not have exactly three dot-separated
                parts, or has a part that is neither "*" nor made only of
                letters, digits, "_" and "-".
        """
        # Explicit raises instead of ``assert`` so validation still runs
        # when Python is started with -O.
        for curr_id in ids:
            if curr_id == "*.*.*":
                raise ValueError('"*.*.*" is not allowed')
            id_parts = curr_id.split(".")
            if len(id_parts) != 3:
                raise ValueError(f"Malformed Entity ID {curr_id}")
            for id_part, part_type in zip(
                id_parts, ["database", "table", "column"]
            ):
                if id_part != "*" and not _ID_PART_PATTERN.fullmatch(id_part):
                    raise ValueError(
                        f"Malformed {part_type} '{id_part}' in {curr_id}"
                    )
        return ids

    def __hash__(self) -> int:
        """
        Provides hash for the object, independent of the order of ``ids``.
        """
        return ",".join(sorted(self.ids)).__hash__()

    def filter(
        self, key: typing.Literal["database", "table", "column"], value: str
    ) -> "EntitySet":
        """
        Generates a new entityset containing only the ids whose ``key`` part
        (database, table or column name) equals ``value``.
        """
        new_ids = []
        for curr_id in self.ids:
            db_name, tab_name, col_name = curr_id.split(".")
            if {"database": db_name, "table": tab_name, "column": col_name}[
                key
            ] == value:
                new_ids.append(curr_id)
        return EntitySet(
            ids=new_ids,
            dataset_schema=self.dataset_schema,
        )

    def invert(self) -> "EntitySet":
        """
        Returns the complement of the current ids with respect to every
        column in ``dataset_schema``.
        """
        all_ids = {
            f"{db}.{tab}.{col}"
            for db, tabval in self.dataset_schema.items()
            for tab, colval in tabval.items()
            for col in colval.keys()
        }
        return EntitySet(
            ids=list(all_ids - set(self.ids)),
            dataset_schema=self.dataset_schema,
        )

    def prune_schema(self) -> BaseDatasetSchema:
        """
        Reduces the schema to only contain the columns present in the ids.

        Returns:
            A ``{db: {table: {column: type}}}`` mapping restricted to ``ids``.
        """
        schema: BaseDatasetSchema = {}
        for curr_id in self.ids:
            dbname, tabname, colname = curr_id.split(".")
            schema.setdefault(dbname, {}).setdefault(tabname, {})[colname] = (
                self.dataset_schema[dbname][tabname][colname]
            )
        return schema

    def model_post_init(self, __context: object) -> None:
        """
        Expands wildcards and drops ids that are not present in the schema.

        Uses an explicit stack: an id with a "*" part is replaced by one id
        per matching name in ``dataset_schema`` and pushed back for further
        expansion; fully-qualified ids are kept only if they exist. The
        result is de-duplicated (overlapping expressions such as "db.*.*"
        and "db.t.c" would otherwise yield repeated ids).
        """
        stack = list(set(self.ids))
        # dict preserves insertion order while de-duplicating.
        resolved_ids: dict[str, None] = {}
        while stack:
            curr_id = stack.pop()
            database, table, column = curr_id.split(".")
            if database == "*":
                stack.extend(
                    [f"{db}.{table}.{column}" for db in self.dataset_schema]
                )
            elif table == "*":
                stack.extend(
                    [
                        f"{database}.{tab}.{column}"
                        for tab in self.dataset_schema.get(database, {})
                    ]
                )
            elif column == "*":
                stack.extend(
                    [
                        f"{database}.{table}.{col}"
                        for col in self.dataset_schema.get(database, {}).get(
                            table, {}
                        )
                    ]
                )
            elif (
                (database in self.dataset_schema)
                and (table in self.dataset_schema[database])
                and (column in self.dataset_schema[database][table])
            ):
                resolved_ids[curr_id] = None
            else:
                logger.debug(
                    f"Invalid Filter Expression Found: {curr_id}. Skipping."
                )
        self.ids = list(resolved_ids)


class Database(BaseModel):
    """
    Implements the core Database class which provides various utilities for
    a DB: schema discovery, entity filtering, query execution and generation
    of per-table descriptions used as langchain custom table info.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    name: str
    db: SQLDatabase
    dsn: AllowedDSN
    dbschema: BaseDatabaseSchema
    # Enum values are collected only for string columns with fewer than this
    # many distinct values; 0 or less disables enum collection.
    enum_limit: int = 10
    descriptor: TableDescriptor = {}
    exclude_entities: EntitySet = EntitySet(ids=[], dataset_schema={})
    data_dictionary: DatabaseDataDictionary | None = None
    table_desc_template: SkipValidation[PromptTemplate] = (
        ZeroShotPrompts.TABLE_DESCRIPTION_V3
    )

    # TODO @madhups - find and remove all uses of SkipValidation across all
    # modules after Langchain has been ported to use Pydantic V2:
    # https://github.com/langchain-ai/langchain/discussions/9337

    @field_serializer("table_desc_template")
    def serialize_prompt_template(
        self, table_desc_template: PromptTemplate, _info
    ):
        """
        Langchain Serializer
        """
        return {
            "template": table_desc_template.template,
            "template_format": table_desc_template.template_format,
        }

    @field_serializer("db")
    def serialize_db_basic(self, db: SQLDatabase, _info):
        # pylint: disable=protected-access
        """
        Langchain Serializer
        """
        return {
            "engine.url": db._engine.url.render_as_string(),
            "_all_tables": db._all_tables,
            "_usable_tables": db._usable_tables,
            "_sample_rows_in_table_info": db._sample_rows_in_table_info,
            "_indexes_in_table_info": db._indexes_in_table_info,
            "_custom_table_info": db._custom_table_info,
            "_max_string_length": db._max_string_length,
        }

    @classmethod
    def fetch_schema(cls, name: str, dsn: AllowedDSN) -> BaseDatasetSchema:
        """
        Queries the database to find out the schema for a given db name and
        DSN.

        Returns:
            ``{name: {table: {column: str(column type)}}}`` for every table
            and view reflected from the database.

        Raises:
            AttributeError: if no tables are found, or a table has no
                columns.
        """
        logger.info(f"[{name}] : Fetching Schema ...")
        engine = create_engine(dsn.unicode_string())
        metadata = MetaData()
        try:
            metadata.reflect(bind=engine, views=True)
        finally:
            # The engine is only needed for reflection; release its pool.
            engine.dispose()

        db_schema: BaseDatasetSchema = {name: {}}
        for tablename, table in metadata.tables.items():
            tabledata = {
                column.name: str(column.type) for column in table.columns
            }
            if not tabledata:
                msg = f"No columns found in {name}.{tablename}"
                logger.critical(msg)
                raise AttributeError(msg)
            db_schema[name][tablename] = tabledata

        if not db_schema[name]:
            msg = f"No tables found in {name}"
            logger.critical(msg)
            raise AttributeError(msg)
        logger.success(f"[{name}] : Schema Obtained Successfully")
        return db_schema

    @classmethod
    def from_connection_string(
        cls,
        name: str,
        connection_string: str,
        schema: BaseDatasetSchema | None = None,
        **kwargs,
    ) -> Self:
        """
        Utility function to create a database from a name and
        connection_string.

        Args:
            name: Name of the database; also the key looked up in ``schema``.
            connection_string: SQLAlchemy connection URL.
            schema: Pre-fetched ``{name: {...}}`` schema; fetched from the
                database when omitted (or empty).
            **kwargs: Extra model fields. ``exclude_entities`` may be given
                as an ``EntitySet`` or as any iterable of entity-id strings.
        """
        logger.debug(f"[{name}] : Analysing ...")
        dsn = AllowedDSN(connection_string)
        schema = schema or cls.fetch_schema(name=name, dsn=dsn)
        engine = create_engine(connection_string)
        assert isinstance(engine, Engine)
        if ("exclude_entities" in kwargs) and (
            not isinstance(kwargs["exclude_entities"], EntitySet)
        ):
            # The field is a single EntitySet (previously this was wrongly
            # wrapped in a list, which failed model validation).
            kwargs["exclude_entities"] = EntitySet(
                ids=list(kwargs["exclude_entities"]), dataset_schema=schema
            )
        db = SQLDatabase(
            engine=engine,
            view_support=True,
        )
        logger.success(f"[{name}] : Analysis Complete")
        return cls(
            name=name,
            dsn=dsn,
            dbschema=schema[name],
            db=db,
            **kwargs,
        )

    def filter(
        self, filters: list[str], filter_type: typing.Literal["only", "exclude"]
    ) -> "Database":
        """
        Returns a new database object after applying the provided filters.

        With ``"exclude"`` the matching entities are excluded; with
        ``"only"`` everything except the matching entities is excluded. The
        resulting exclusions replace (they are not merged with) this
        object's ``exclude_entities``.
        """
        entities = EntitySet(
            ids=filters, dataset_schema={self.name: self.dbschema}
        )
        if filter_type == "only":
            entities = entities.invert()
        return Database(
            name=self.name,
            db=self.db,
            dsn=self.dsn,
            dbschema=self.dbschema,
            enum_limit=self.enum_limit,
            descriptor=self.descriptor,
            exclude_entities=entities,
            data_dictionary=self.data_dictionary,
            table_desc_template=self.table_desc_template,
        )

    def execute(self, query: str) -> pd.DataFrame:
        # pylint: disable=protected-access
        """
        Returns the results of a query as a Pandas DataFrame
        """
        return pd.read_sql(sql=query, con=self.db._engine)

    def _column_description(self, table_name: str, col_name: str) -> str | None:
        """
        Returns the data-dictionary description of a column, or None when
        there is no data dictionary or it does not mention the column.
        """
        if not self.data_dictionary:
            return None
        tables = self.data_dictionary["tables"]
        if (table_name in tables) and (
            col_name in tables[table_name]["columns"]
        ):
            return tables[table_name]["columns"][col_name]["description"]
        return None

    def _describe_columns(
        self, table: Table, all_exclusions: set[str], engine: Engine
    ) -> dict[str, BaseColDescriptor]:
        # pylint: disable=protected-access
        """
        Builds the per-column descriptors for one reflected table.

        Side effect: columns listed in ``all_exclusions`` that are not part
        of any table constraint are removed from ``table`` itself (presumably
        so the CREATE TABLE statement and sample rows generated afterwards
        from the same Table object omit them). Constrained columns are kept
        even when excluded.

        Descriptors already present in ``self.descriptor`` are reused;
        new ones get enum values (see ``enum_limit``) queried through
        ``engine`` in a single UNION query.
        """
        constrained_columns = {
            col.name
            for con in table.constraints
            for col in con.columns  # type: ignore
        }
        col_descriptor: dict[str, BaseColDescriptor] = {}
        col_enums = []
        # Iterate over a snapshot: excluded columns are removed from the
        # table's column collection inside the loop.
        for col in list(table._columns):  # type: ignore
            if (col.name not in constrained_columns) and (
                f"{self.name}.{table.name}.{col.name}" in all_exclusions
            ):
                logger.info(
                    f"[{self.name}.{table.name}] : Removing column {col.name}"
                )
                table._columns.remove(col)  # type: ignore
                continue

            if (table.name in self.descriptor) and (
                col.name in self.descriptor[table.name]["col_descriptor"]
            ):
                # Reuse a descriptor computed earlier (e.g. handed over by
                # Database.filter via ``descriptor=self.descriptor``).
                col_descriptor[col.name] = self.descriptor[table.name][
                    "col_descriptor"
                ][col.name]
                continue

            if (self.enum_limit > 0) and _is_string_column(col):
                col_enums.append(
                    sqe.select(
                        sqe.literal(col.name, VARCHAR).label("COLNAME"),
                        # CASE without ELSE: the column's values when it has
                        # fewer than enum_limit distinct values, else NULL.
                        sqe.case(
                            (
                                sqe.select(
                                    func.count(sqe.distinct(col))
                                    < self.enum_limit
                                ).label("COLCOUNT"),
                                col,
                            )
                        ).label("COLVALS"),
                    ).distinct()
                )
            col_descriptor[col.name] = {
                "col_type": str(col.type),
                "col_nullable": col.nullable,
                "col_pk": col.primary_key,
                "col_defval": col.default,
                "col_comment": col.comment,
                "col_enum_vals": None,
                "col_description": self._column_description(
                    table.name, col.name
                ),
            }

        if col_enums:
            # NULL (too many distinct values) and empty strings are dropped.
            enum_values = (
                pd.read_sql(sql=sqe.union(*col_enums), con=engine)
                .replace("", np.nan)
                .dropna()
                .groupby("COLNAME", group_keys=False)["COLVALS"]
                .apply(list)
                .to_dict()
            )
            for colname, colvals in enum_values.items():
                col_descriptor[colname]["col_enum_vals"] = colvals
        return col_descriptor

    def model_post_init(self, __context: object) -> None:
        # pylint: disable=protected-access
        """
        Pydantic post-init hook that prepares the database for prompting.

        1. A table is excluded when every one of its columns appears in
           ``exclude_entities``.
        2. A fresh ``SQLDatabase`` is created from ``dsn`` ignoring those
           tables.
        3. For every remaining table a ``BaseTabDescriptor`` (creation
           statement, sample rows, column descriptors) is built and
           ``table_desc_template`` is rendered with it.
        4. ``self.descriptor`` receives the descriptors, and ``self.db`` is
           replaced by the new ``SQLDatabase`` whose custom table info is
           the rendered descriptions.
        """
        logger.debug(f"[{self.name}] : Instantiating ...")
        logger.debug(f"[{self.name}] : Calculating Exclusions ...")
        all_exclusions = set(self.exclude_entities.ids)
        table_exclusions = [
            tablename
            for tablename, tableinfo in self.dbschema.items()
            if {
                f"{self.name}.{tablename}.{column}" for column in tableinfo
            }.issubset(all_exclusions)
        ]
        if table_exclusions:
            logger.info(
                f"[{self.name}] : These tables will be excluded :"
                + ", ".join(table_exclusions)
            )
        else:
            logger.info(f"[{self.name}] : No tables will be excluded")
        logger.success(f"[{self.name}] : Exclusions Calculated")

        logger.debug(f"[{self.name}] : Generating Custom Descriptions ...")
        engine = create_engine(self.dsn.unicode_string())
        assert isinstance(engine, Engine)
        temp_db = SQLDatabase(
            engine=engine,
            ignore_tables=table_exclusions,
            view_support=True,
        )
        table_descriptor: dict[str, BaseTabDescriptor] = {}
        table_descriptions: dict[str, str] = {}
        for table in temp_db._metadata.sorted_tables:
            if table.name in table_exclusions:
                continue
            # Must run before CreateTable / sample rows: it may remove
            # excluded columns from ``table``.
            col_descriptor = self._describe_columns(
                table, all_exclusions, engine
            )
            table_descriptor[table.name] = {
                "table_name": table.name,
                "table_creation_statement": str(
                    CreateTable(table).compile(engine)
                ).rstrip(),
                "table_sample_rows": temp_db._get_sample_rows(table),
                "col_descriptor": col_descriptor,
            }
            logger.trace(
                f"[{self.name}] : Table descriptor created for {table.name}"
                + f"\n{table_descriptor[table.name]}"
            )
            table_descriptions[table.name] = self.table_desc_template.format(
                **{
                    key: value
                    for key, value in table_descriptor[table.name].items()
                    if key in self.table_desc_template.input_variables
                }
            )
        self.descriptor = table_descriptor
        logger.success(f"[{self.name}] : Custom Descriptions Generated")
        temp_db._custom_table_info = table_descriptions
        self.db = temp_db
        logger.success(f"[{self.name}] : Instantiated")


DBNameDBMap = dict[DBName, Database]


class Dataset(BaseModel):
    """
    A dataset is a collection of databases
    """

    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
    databases: DBNameDBMap
    dataset_schema: BaseDatasetSchema
    exclude_entities: EntitySet = EntitySet(ids=[], dataset_schema={})
    data_dictionary: DatasetDataDictionary = {}
    enum_limit: int = 10
    table_desc_template: SkipValidation[PromptTemplate] = (
        ZeroShotPrompts.TABLE_DESCRIPTION_V3
    )

    # TODO @madhups - remove SkipValidation after Pydantic V2 support is added
    # to Langchain: https://github.com/langchain-ai/langchain/discussions/9337

    def get_database(self, database_name: str) -> Database:
        """
        Utility function to fetch a specific database from a dataset.

        Raises:
            KeyError: if ``database_name`` is not part of this dataset.
        """
        return self.databases[database_name]

    def filter(
        self,
        filters: list[str],
        filter_type: typing.Literal["only", "exclude"],
        prune: bool = False,
    ) -> "Dataset":
        """
        Applies a filter to every database of this dataset and provides a
        new instance.

        Args:
            filters: Entity ids ("db.table.column", wildcards allowed).
            filter_type: "only" keeps just the matching entities, "exclude"
                removes them (see ``Database.filter``).
            prune: When True, databases whose ``db.table_info`` is empty
                after filtering are dropped from the result.
        """
        databases = {
            name: database.filter(filters, filter_type)
            for name, database in self.databases.items()
        }
        if prune:
            # Iterate items(); iterating the dict itself yields only keys.
            databases = {
                name: database
                for name, database in databases.items()
                if database.db.table_info
            }
        return Dataset(
            databases=databases,
            dataset_schema=self.dataset_schema,
            exclude_entities=self.exclude_entities,
            data_dictionary=self.data_dictionary,
            enum_limit=self.enum_limit,
            table_desc_template=self.table_desc_template,
        )

    @property
    def list_databases(self) -> list[str]:
        """
        Returns a list of databases in this dataset
        """
        return list(self.databases.keys())

    @field_serializer("table_desc_template")
    def serialize_prompt_template(
        self, table_desc_template: PromptTemplate, _info
    ):
        """
        Langchain Serializer
        """
        return {
            "template": table_desc_template.template,
            "template_format": table_desc_template.template_format,
        }

    @classmethod
    def from_connection_strings(
        cls,
        name_connstr_map: dict[str, str],
        exclude_entities: list[str] | None = None,
        **kwargs,
    ) -> Self:
        """
        Utility function to create a dataset from a name -> conn_str mapping.

        Args:
            name_connstr_map: Database name -> SQLAlchemy connection string.
            exclude_entities: Entity ids ("db.table.column", wildcards
                allowed) to exclude; each database receives only the ids
                that refer to it.
            **kwargs: ``data_dictionary`` (db name -> data dictionary) is
                split per database; every other keyword is forwarded both to
                ``Database.from_connection_string`` and to this class.
        """
        dataset_schema = {
            db_name: Database.fetch_schema(db_name, AllowedDSN(db_connstr))[
                db_name
            ]
            for db_name, db_connstr in name_connstr_map.items()
        }
        parsed_exclude_entities = EntitySet(
            ids=list(exclude_entities or []), dataset_schema=dataset_schema
        )
        data_dictionary = kwargs.pop("data_dictionary", None) or {}
        databases = {
            db_name: Database.from_connection_string(
                name=db_name,
                connection_string=db_connstr,
                exclude_entities=parsed_exclude_entities.filter(
                    key="database", value=db_name
                ),
                schema={db_name: dataset_schema[db_name]},
                data_dictionary=data_dictionary.get(db_name),
                **kwargs,
            )
            for db_name, db_connstr in name_connstr_map.items()
        }
        return cls(
            databases=databases,
            dataset_schema=dataset_schema,
            exclude_entities=parsed_exclude_entities,
            data_dictionary=data_dictionary,
            **kwargs,
        )