def model_post_init()

in nl2sql_library/nl2sql/datasets/base.py [0:0]


    def model_post_init(self, __context: object) -> None:
        # pylint: disable=protected-access, too-many-branches
        """
        Langchain's Post-Init method to properly validate DB

        Steps:

        1. Compute table exclusions: a table from ``self.dbschema`` is
           excluded when every ``"<self.name>.<table>.<column>"`` id for it
           appears in ``self.exclude_entities.ids``.
        2. Reflect the database at ``self.dsn`` into a ``SQLDatabase``
           (views included, excluded tables ignored).
        3. For every remaining table, remove excluded columns that are not
           part of any table constraint, and build one descriptor per kept
           column: type, nullability, primary-key flag, default, comment,
           data-dictionary description and, for ``str``-typed columns when
           ``self.enum_limit > 0``, the distinct non-empty values if there
           are fewer than ``self.enum_limit`` of them. Column descriptors
           already present in ``self.descriptor`` are reused unchanged
           instead of being recomputed.
        4. Render each table descriptor with ``self.table_desc_template``.

        Side effects: replaces ``self.descriptor`` with the freshly built
        table descriptors, installs the rendered descriptions as the
        reflected database's custom table info, and stores that database
        on ``self.db``.

        Args:
            __context: Pydantic post-init context; not used here.
        """
        logger.debug(f"[{self.name}] : Instantiating ...")
        logger.debug(f"[{self.name}] : Calculating Exclusions ...")
        all_exclusions = set(self.exclude_entities.ids)
        # A table is dropped entirely only when *all* of its columns are
        # excluded; partial exclusions are handled per column further down.
        table_exclusions = [
            tablename
            for tablename, tableinfo in self.dbschema.items()
            if {
                f"{self.name}.{tablename}.{column}"
                for column in tableinfo.keys()
            }.issubset(all_exclusions)
        ]
        if table_exclusions:
            logger.info(
                f"[{self.name}] : These tables will be excluded :"
                + (", ".join(table_exclusions))
            )
        else:
            logger.info(f"[{self.name}] : No tables will be excluded")
        logger.success(f"[{self.name}] : Exclusions Calculated")
        logger.debug(f"[{self.name}] : Generating Custom Descriptions ...")
        engine = create_engine(self.dsn.unicode_string())
        assert isinstance(engine, Engine)
        temp_db = SQLDatabase(
            engine=engine,
            ignore_tables=table_exclusions,
            view_support=True,
        )
        excluded_tables = set(table_exclusions)
        # Data-dictionary descriptions are optional at every level: a missing
        # "tables", table, "columns", column or "description" entry yields a
        # description of None instead of raising KeyError.
        dd_tables = (
            (self.data_dictionary.get("tables") or {})
            if self.data_dictionary
            else {}
        )
        table_descriptor: dict[str, BaseTabDescriptor] = {}
        table_descriptions = {}
        for table in temp_db._metadata.sorted_tables:
            if table.name in excluded_tables:
                continue
            # Columns referenced by any table constraint are never removed,
            # even when they are listed in the exclusions.
            constraints = {
                col.name
                for con in table.constraints
                for col in con.columns  # type: ignore
            }
            known_col_descriptors = (
                self.descriptor[table.name]["col_descriptor"]
                if table.name in self.descriptor
                else {}
            )
            dd_columns = (
                (dd_tables.get(table.name) or {}).get("columns") or {}
            )
            col_descriptor: dict[str, BaseColDescriptor] = {}
            col_enums = []
            # Iterate over a snapshot: excluded columns are removed from the
            # table's column collection inside this loop.
            for col in list(table._columns):  # type: ignore
                if (col.name not in constraints) and (
                    f"{self.name}.{table.name}.{col.name}" in all_exclusions
                ):
                    logger.info(
                        f"[{self.name}.{table.name}] : "
                        f"Removing column {col.name}"
                    )
                    table._columns.remove(col)  # type: ignore
                    continue

                if col.name in known_col_descriptors:
                    # Reuse the existing descriptor verbatim (no enum query).
                    col_descriptor[col.name] = known_col_descriptors[col.name]
                    continue

                try:
                    is_text_col = (
                        self.enum_limit > 0 and col.type.python_type is str
                    )
                except NotImplementedError:
                    # TypeEngine.python_type raises for types without a known
                    # Python equivalent (e.g. NullType, used for unrecognised
                    # reflected types); such columns get no enum values.
                    is_text_col = False
                if is_text_col:
                    # One row per distinct value, labelled with the column
                    # name. CASE yields the value only when the column has
                    # fewer than enum_limit distinct values, else NULL
                    # (those rows are dropped after the query runs).
                    col_enums.append(
                        sqe.select(
                            sqe.literal(col.name, VARCHAR).label("COLNAME"),
                            sqe.case(
                                (
                                    sqe.select(
                                        func.count(sqe.distinct(col))
                                        < self.enum_limit
                                    ).label("COLCOUNT"),
                                    col,
                                )
                            ).label("COLVALS"),
                        ).distinct()
                    )

                dd_column = dd_columns.get(col.name) or {}
                col_descriptor[col.name] = {
                    "col_type": str(col.type),
                    "col_nullable": col.nullable,
                    "col_pk": col.primary_key,
                    "col_defval": col.default,
                    "col_comment": col.comment,
                    # Filled in from the enum query below, when applicable.
                    "col_enum_vals": None,
                    "col_description": dd_column.get("description"),
                }

            if col_enums:
                # Single UNION query per table; empty strings and NULLs are
                # discarded, then values are grouped per column name.
                enum_vals = (
                    pd.read_sql(sql=sqe.union(*col_enums), con=engine)
                    .replace("", np.nan)
                    .dropna()
                    .groupby("COLNAME", group_keys=False)["COLVALS"]
                    .apply(list)
                    .to_dict()
                )
                for colname, colvals in enum_vals.items():
                    col_descriptor[colname]["col_enum_vals"] = colvals

            table_descriptor[table.name] = {
                "table_name": table.name,
                # Compiled after column removal, so the DDL omits the
                # excluded columns.
                "table_creation_statement": str(
                    CreateTable(table).compile(engine)
                ).rstrip(),
                "table_sample_rows": temp_db._get_sample_rows(table),
                "col_descriptor": col_descriptor,
            }

            logger.trace(
                f"[{self.name}] : Table descriptor created for {table.name}"
                + f"\n{table_descriptor[table.name]}"
            )
            # Only pass the descriptor fields the template actually uses.
            table_descriptions[table.name] = self.table_desc_template.format(
                **{
                    key: value
                    for key, value in table_descriptor[table.name].items()
                    if key in self.table_desc_template.input_variables
                }
            )

        self.descriptor = table_descriptor
        logger.success(f"[{self.name}] : Custom Descriptions Generated")
        temp_db._custom_table_info = table_descriptions
        self.db = temp_db
        logger.success(f"[{self.name}] : Instantiated")