in nl2sql_library/nl2sql/datasets/base.py [0:0]
def model_post_init(self, __context: object) -> None:
    # pylint: disable=protected-access, too-many-branches
    """
    Post-init hook: connect to the database and build table descriptors.

    Steps, in order:

    1. Work out which tables are fully excluded, i.e. every column id
       ``<name>.<table>.<column>`` built from ``self.dbschema`` is listed
       in ``self.exclude_entities.ids``.
    2. Create an engine from ``self.dsn`` and wrap it in a ``SQLDatabase``
       that ignores those tables.
    3. For every remaining table: drop excluded columns that take part in
       no constraint, build (or reuse from ``self.descriptor``) one
       descriptor per column, collect the distinct values of string
       columns when ``self.enum_limit`` is positive, and render the table
       description with ``self.table_desc_template``.
    4. Store the results on ``self.descriptor`` and ``self.db``.

    Args:
        __context: Not used by this method.
    """
    logger.debug(f"[{self.name}] : Instantiating ...")

    # ------------------------------------------------------------------
    # 1. Table-level exclusions
    # ------------------------------------------------------------------
    logger.debug(f"[{self.name}] : Calculating Exclusions ...")
    table_exclusions = []
    all_exclusions = set(self.exclude_entities.ids)
    for tablename, tableinfo in self.dbschema.items():
        # A table is excluded only when *all* of its columns are excluded.
        if {
            f"{self.name}.{tablename}.{column}"
            for column in tableinfo.keys()
        }.issubset(all_exclusions):
            table_exclusions.append(tablename)
    if table_exclusions:
        logger.info(
            f"[{self.name}] : These tables will be excluded :"
            + (", ".join(table_exclusions))
        )
    else:
        logger.info(f"[{self.name}] : No tables will be excluded")
    logger.success(f"[{self.name}] : Exclusions Calculated")

    # ------------------------------------------------------------------
    # 2. Database handle
    # ------------------------------------------------------------------
    logger.debug(f"[{self.name}] : Generating Custom Descriptions ...")
    engine = create_engine(self.dsn.unicode_string())
    assert isinstance(engine, Engine)
    temp_db = SQLDatabase(
        engine=engine,
        ignore_tables=table_exclusions,
        view_support=True,
    )

    # ------------------------------------------------------------------
    # 3. Per-table descriptors
    # ------------------------------------------------------------------
    # Built once so the per-table membership test below is O(1).
    excluded_tables = set(table_exclusions)
    table_descriptor: dict[str, BaseTabDescriptor] = {}
    table_descriptions = {}
    for table in temp_db._metadata.sorted_tables:
        if table.name in excluded_tables:
            continue

        # Columns referenced by any constraint are never removed, even if
        # they are listed in the exclusions.
        constraints = {
            col.name
            for con in table.constraints
            for col in con.columns  # type: ignore
        }
        col_descriptor: dict[str, BaseColDescriptor] = {}
        col_enums = []

        # Iterate over a snapshot: columns may be removed from the
        # collection inside the loop.
        for col in list(table._columns):  # type: ignore
            col_id = f"{self.name}.{table.name}.{col.name}"
            if (col.name not in constraints) and (col_id in all_exclusions):
                logger.info(
                    f"[{self.name}.{table.name}] : Removing column {col.name}"
                )
                table._columns.remove(col)  # type: ignore
                continue

            # Reuse a descriptor supplied up-front for this column.
            if (table.name in self.descriptor) and (
                col.name in self.descriptor[table.name]["col_descriptor"]
            ):
                col_descriptor[col.name] = self.descriptor[table.name][
                    "col_descriptor"
                ][col.name]
                continue

            if self.enum_limit > 0:
                # python_type raises NotImplementedError for column types
                # without a known Python equivalent; such columns are
                # simply not treated as enum candidates.
                try:
                    is_text_column = col.type.python_type is str
                except NotImplementedError:
                    is_text_column = False
                if is_text_column:
                    # One SELECT per column yielding (COLNAME, COLVALS).
                    # COLVALS is the column value only when the number of
                    # distinct values is below enum_limit; the CASE has no
                    # ELSE branch, so it is NULL otherwise.
                    col_enums.append(
                        sqe.select(
                            sqe.literal(col.name, VARCHAR).label("COLNAME"),
                            sqe.case(
                                (
                                    sqe.select(
                                        func.count(sqe.distinct(col))
                                        < self.enum_limit
                                    ).label("COLCOUNT"),
                                    col,
                                )
                            ).label("COLVALS"),
                        ).distinct()
                    )

            # Description from the data dictionary, when it has an entry
            # for this table and column.
            col_description = None
            if (
                (self.data_dictionary)
                and (table.name in self.data_dictionary["tables"])
                and (
                    col.name
                    in self.data_dictionary["tables"][table.name]["columns"]
                )
            ):
                col_description = self.data_dictionary["tables"][table.name][
                    "columns"
                ][col.name]["description"]

            col_descriptor_map: BaseColDescriptor = {
                "col_type": str(col.type),
                "col_nullable": col.nullable,
                "col_pk": col.primary_key,
                "col_defval": col.default,
                "col_comment": col.comment,
                "col_enum_vals": None,  # filled in below when available
                "col_description": col_description,
            }
            col_descriptor[col.name] = col_descriptor_map

        # Run all enum probes for this table as a single UNION query and
        # attach the value lists to the matching column descriptors.
        if col_enums:
            enum_values = (
                pd.read_sql(sql=sqe.union(*col_enums), con=engine)
                .replace("", np.nan)  # treat empty strings as missing
                .dropna()
                .groupby("COLNAME", group_keys=False)["COLVALS"]
                .apply(list)
                .to_dict()
            )
        else:
            enum_values = {}
        for colname, colvals in enum_values.items():
            col_descriptor[colname]["col_enum_vals"] = colvals

        table_descriptor[table.name] = {
            "table_name": table.name,
            "table_creation_statement": str(
                CreateTable(table).compile(engine)
            ).rstrip(),
            "table_sample_rows": temp_db._get_sample_rows(table),
            "col_descriptor": col_descriptor,
        }
        logger.trace(
            f"[{self.name}] : Table descriptor created for {table.name}"
            + f"\n{table_descriptor[table.name]}"
        )

        # Only pass the template the variables it actually declares.
        table_descriptions[table.name] = self.table_desc_template.format(
            **{
                key: value
                for key, value in table_descriptor[table.name].items()
                if key in self.table_desc_template.input_variables
            }
        )

    # ------------------------------------------------------------------
    # 4. Publish results
    # ------------------------------------------------------------------
    self.descriptor = table_descriptor
    logger.success(f"[{self.name}] : Custom Descriptions Generated")
    temp_db._custom_table_info = table_descriptions
    self.db = temp_db
    logger.success(f"[{self.name}] : Instantiated")