nl2sql_library/nl2sql/datasets/custom.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Allows creating custom Datasets on a local, temp PGSQL/ MySQL instance
"""
import os
import typing
from functools import lru_cache
import numpy as np
import openpyxl
import pandas as pd
from google.api_core.exceptions import NotFound
from google.cloud import bigquery
from google.cloud.bigquery import SchemaField
from loguru import logger
from nl2sql.datasets.base import Dataset


@np.vectorize
def generate_pk_query(dataset_id: str,
tablename: str,
primary_key_column: str
) -> str:
"""
Generate DDL queries to add associated primary key columns to respective
tables.
Args:
dataset_id (str): Bigquery dataset id.
tablename (str): Bigquery table name.
primary_key_column (str): Name of the primary key column in table.
Returns:
query (str): DDL query to add primary key to table.
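
    Example (illustrative values; since the function is wrapped in
    @np.vectorize it is normally applied to whole DataFrame columns):
        generate_pk_query("my-project.sales", "orders", "order_id")
        -> ALTER TABLE `my-project.sales.orders`
           ADD PRIMARY KEY(order_id) NOT ENFORCED;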
"""
query = (
f"ALTER TABLE `{dataset_id}.{tablename}` "
f"ADD PRIMARY KEY({primary_key_column}) NOT ENFORCED;"
)
return query


@np.vectorize
def generate_fk_query(
dataset_id: str, tablename: str, foreign_key_column: str, references: str
) -> str:
"""
Generate DDL queries to add associated foreign key columns to respective
tables and their references.
Args:
dataset_id (str): Bigquery dataset id.
tablename (str): Bigquery table name.
foreign_key_column (str): Name of the foreign key column in table.
references (str): Reference column for foreign key.
Returns:
query (str): DDL query to add foreign key to table.
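
    Example (illustrative values; the References cell is expected in the
    form "table(column)"):
        generate_fk_query("my-project.sales", "orders",
                          "customer_id", "customers(customer_id)")
        -> ALTER TABLE `my-project.sales.orders` ADD FOREIGN KEY(customer_id)
           REFERENCES `my-project.sales`.customers(customer_id) NOT ENFORCED;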
"""
query = (
f"ALTER TABLE `{dataset_id}.{tablename}` "
f"ADD FOREIGN KEY({foreign_key_column}) "
f"REFERENCES `{dataset_id}`.{references.replace(' ','')} "
"NOT ENFORCED;"
)
return query


class CustomDataset:
    """
    Builds a custom BigQuery dataset from an Excel workbook.
    """

def __init__(self,
filepath: str,
project_id: str | None,
dataset_name: str
):
"""
Custom Dataset
Args:
            filepath (str): Path to the Excel file containing table schemas
                and data.
            project_id (str | None): GCP project ID.
            dataset_name (str): Name of the BigQuery dataset to be created.
"""
self.filepath = filepath
self.dataset_name = dataset_name
self.client = bigquery.Client(project=project_id, location="US")

    @classmethod
@lru_cache
def from_excel(
cls,
filepath: str,
project_id: str | None = None,
dataset_name: str = "custom_dataset",
) -> Dataset:
"""
Creates and returns a Dataset object based on the specified excel file.
Args:
filepath (str): File path where the input excel file is located.
project_id (str | None, optional):
GCP Project ID where bigquery dataset is created.
Defautls to Environment variable "GOOGLE_CLOUD_PROJECT".
dataset_name (str, optional):
Name of the Custom Bigquery Dataset.
Defaults to "custom_dataset".
Returns:
Dataset: A Dataset Object.
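
        Example (illustrative; assumes "tables.xlsx" exists and BigQuery
        credentials are configured):
            dataset = CustomDataset.from_excel(
                filepath="tables.xlsx",
                project_id="my-project",
                dataset_name="my_dataset",
            )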
"""
if project_id is None:
project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
custom = cls(
filepath=filepath, project_id=project_id, dataset_name=dataset_name
)
dataset_id = f"{project_id}.{dataset_name}"
try:
dataset = custom.client.get_dataset(dataset_id)
logger.info(f"Dataset {dataset_id} already exists.")
except NotFound:
dataset = bigquery.Dataset(dataset_id)
dataset = custom.client.create_dataset(dataset)
logger.success(f"Created dataset {dataset_id}.")
custom.create_tables()
custom.update_key_columns(dataset_id=dataset_id)
        return Dataset.from_connection_strings(
            name_connstr_map={
                dataset_name: f"bigquery://{project_id}/{dataset_name}"
            }
        )

    def generate_bigquery_schema(
self, table_df: pd.DataFrame
) -> typing.List[SchemaField]:
"""Generate a Bigquery compatible schema from Pandas Dataframe.
Args:
table_df (pd.DataFrame): Table Dataframe.
Returns:
typing.List[SchemaField]: Bigquery Compatibel Schema.
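
        Example (illustrative): a DataFrame with an integer column "id" and
        a string column "name" yields roughly:
            [SchemaField("id", "INTEGER", mode="NULLABLE"),
             SchemaField("name", "STRING", mode="NULLABLE")]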
"""
type_mapping = {
"i": "INTEGER",
"u": "NUMERIC",
"b": "BOOLEAN",
"f": "FLOAT",
"O": "STRING",
"S": "STRING",
"U": "STRING",
"M": "TIMESTAMP",
}
schema = []
for column, dtype in table_df.dtypes.items():
val = table_df[column].iloc[0]
mode = "REPEATED" if isinstance(val, list) else "NULLABLE"
if isinstance(val, dict) or (
mode == "REPEATED" and isinstance(val[0], dict)
):
fields = self.generate_bigquery_schema(pd.json_normalize(val))
else:
fields = []
type_ = "RECORD" if fields else type_mapping.get(dtype.kind)
schema.append(
SchemaField(
name=column, # type: ignore
field_type=type_, # type: ignore
mode=mode,
fields=fields,
)
)
return schema

    def create_tables(self):
        """
        Create one BigQuery table per data sheet in the Excel file.
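
        Expected workbook layout (inferred from the loading logic): every
        sheet except "Primary Keys" and "Foreign Keys" holds one table, the
        sheet name becomes the table name, and the first row holds the
        column headers.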
"""
workbook = openpyxl.load_workbook(self.filepath)
sheetnames = workbook.sheetnames
sheetnames = [
sheetname
for sheetname in sheetnames
if sheetname not in ["Primary Keys", "Foreign Keys"]
]
for sheetname in sheetnames:
table_id = f"{self.dataset_name}.{sheetname}"
table_df = pd.read_excel(self.filepath, sheet_name=sheetname)
table_df = table_df.convert_dtypes()
schema = self.generate_bigquery_schema(table_df)
job_config = bigquery.LoadJobConfig(
schema=schema, write_disposition="WRITE_TRUNCATE"
)
job = self.client.load_table_from_dataframe(
table_df, table_id, job_config=job_config
)
job.result()
logger.success(f"Created table {table_id}")

    def update_key_columns(self, dataset_id):
        """
        Add primary- and foreign-key constraints to tables in the dataset.

        Args:
            dataset_id (str): Fully qualified BigQuery dataset id, i.e.
                "project.dataset".
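
        Expected sheets (inferred from the column names used below):
            "Primary Keys": columns "Table" and "Primary Key".
            "Foreign Keys": columns "Table", "Foreign Key" and "References",
                where "References" takes the form "table(column)".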
"""
try:
pkey = pd.read_excel(self.filepath, sheet_name="Primary Keys")
fkey = pd.read_excel(self.filepath, sheet_name="Foreign Keys")
pkey["Query"] = generate_pk_query(
dataset_id, pkey["Table"], pkey["Primary Key"]
)
fkey["Query"] = generate_fk_query(
dataset_id, fkey["Table"],
fkey["Foreign Key"],
fkey["References"]
)
            # Wait on each DDL job so that failures surface immediately
            # instead of being silently dropped.
            for pk_query in pkey["Query"].tolist():
                self.client.query(pk_query).result()
            for fk_query in fkey["Query"].tolist():
                self.client.query(fk_query).result()
except ValueError as err:
logger.error(f"Sheetname value error: {err}")