tools/dlp-to-data-catalog/dlp/preprocess.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processes input data to fit to DLP inspection standards."""
import dataclasses
from enum import Enum
from typing import List, Tuple, Dict
from google.api_core.exceptions import NotFound
from google.cloud import bigquery, dlp_v2
from google.cloud.sql.connector import Connector
from sqlalchemy import create_engine, MetaData, Table, func, select, inspect
@dataclasses.dataclass
class Bigquery:
"""Represents a connection to a Google BigQuery dataset and table."""
bq_client: bigquery.Client
dataset: str
table: str
@dataclasses.dataclass
class CloudSQL:
    """Represents a connection to a Google Cloud SQL database."""
connector: Connector
connection_name: str
service_account: str
db_name: str
table: str
driver: str
connection_type: str
class Database(Enum):
"""Represents available sources for database connections."""
BIGQUERY = "bigquery"
CLOUDSQL = "cloudsql"
class Preprocessing:
"""Converts input data into Data Loss Prevention tables."""
def __init__(
self,
source: str,
project: str,
zone: str,
**preprocess_args
):
"""Initializes `Preprocessing` class with arguments.
Args:
source (str): The name of the source of data used.
project (str): The name of the Google Cloud Platform project.
            zone (str): The name of the zone (for Cloud SQL, the region
                of the instance).
**preprocess_args: Additional arguments for preprocessing.
Supported arguments are:
- bigquery_args(Dict):
- dataset (str): The name of the BigQuery dataset.
                    - table (str, optional): The name of the BigQuery table.
                        If not provided, the entire dataset is scanned.
                        Defaults to None.
- cloudsql_args(Dict):
- instance (str): Name of the database instance.
- service_account(str): Service account email to be used.
- db_name(str): The name of the database.
- table (str): The name of the table.
- db_type(str): The type of the database.
e.g. postgres, mysql.
"""
self.source = Database(source)
self.project = project
self.zone = zone
if self.source == Database.BIGQUERY:
# Handle BigQuery source.
bigquery_args = preprocess_args.get("bigquery_args", {})
self.bigquery = Bigquery(bigquery.Client(project=project),
bigquery_args["dataset"],
bigquery_args["table"])
elif self.source == Database.CLOUDSQL:
# Handle Cloud SQL source.
cloudsql_args = preprocess_args.get("cloudsql_args", {})
instance = cloudsql_args["instance"]
db_type = cloudsql_args["db_type"]
# Determine the appropriate database driver and connection
# name based on db_type.
if db_type == "mysql":
driver = "pymysql"
connection_name = f"mysql+{driver}"
            elif db_type == "postgres":
                driver = "pg8000"
                connection_name = f"postgresql+{driver}"
            else:
                raise ValueError(
                    f"Unsupported db_type '{db_type}'. "
                    "Expected 'mysql' or 'postgres'.")
self.cloudsql = CloudSQL(
Connector(),
f"{project}:{zone}:{instance}",
cloudsql_args["service_account"],
cloudsql_args["db_name"],
cloudsql_args["table"],
driver,
connection_name)
def get_connection(self):
"""Returns a connection to the database.
Returns:
A connection object that can be used to execute queries.
"""
        connection = self.cloudsql.connector.connect(
            self.cloudsql.connection_name,
            self.cloudsql.driver,
            enable_iam_auth=True,
            user=self.cloudsql.service_account,
            db=self.cloudsql.db_name
        )
        return connection
def get_cloudsql_tables(self):
"""Returns a list of all tables in the CloudSQL database.
Returns:
A list of table names.
"""
# Create a database engine instance.
engine = create_engine(
f'{self.cloudsql.connection_type}://', creator=self.get_connection)
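        # Reflect the list of table names from the connected database.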
table_names = inspect(engine).get_table_names()
return table_names
def get_cloudsql_data(
self,
table: str,
batch_size: int,
start_index: int
) -> Tuple[List, List]:
"""Retrieves the schema and content of a table from CloudSQL.
Args:
table (str): The name of the table.
batch_size (int): The block of cells to be analyzed.
start_index (int): The starting index of each block to be analyzed.
Returns:
Tuple(List, List): A tuple containing the schema and content
as a List.
"""
# Create a database engine instance.
engine = create_engine(
f'{self.cloudsql.connection_type}://', creator=self.get_connection)
# Create a Metadata and Table instance.
metadata = MetaData()
table = Table(table, metadata, extend_existing=True,
autoload_with=engine)
num_columns = len(table.columns.keys())
# Get table schema.
schema = [column.name for column in table.columns]
# Get table contents.
with engine.connect() as connection:
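            # batch_size and start_index count cells; dividing by the number
            # of columns converts them into row-based LIMIT/OFFSET values.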
query = table.select().with_only_columns(table.columns) \
.limit(int(batch_size/num_columns)) \
.offset(int(start_index/num_columns))
content = list(connection.execute(query).fetchall())
return schema, content
def get_bigquery_tables(self, dataset: str) -> List[str]:
"""Constructs a list of table names from a BigQuery dataset.
Args:
            dataset (str): Name of the dataset in BigQuery.
        Returns:
            List[str]: A list of table names.
"""
dataset_tables = list(self.bigquery.bq_client.list_tables(dataset))
table_names = [table.table_id for table in dataset_tables]
return table_names
def fetch_rows(
self,
table_bq: bigquery.table.Table,
start_index: int,
batch_size: int
) -> List[Dict]:
"""Fetches a batch of rows from a BigQuery table.
Args:
table_bq (bigquery.table.Table) : The path of the table
where the data is fetched.
start_index (int) : The starting index of each block to
be analyzed.
batch_size (int) : The block of cells to be analyzed.
Returns:
List[Dict]: A list of rows, where each row is a tuple
containing the values for each field in the table schema.
"""
content = []
num_columns = len(table_bq.schema)
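        # Translate the cell-based batch_size/start_index into row counts,
        # since list_rows paginates by row.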
        rows_iter = self.bigquery.bq_client.list_rows(
            table=table_bq, start_index=int(start_index/num_columns),
            max_results=int(batch_size/num_columns))
        if not rows_iter.total_rows:
            print(f"The table {table_bq.table_id} is empty. "
                  "Please populate the table and try again.")
else:
for row in rows_iter:
content.append(tuple(row))
return content
def get_table_schema(self, table_id: str) -> Tuple[List, List, List]:
"""Generates a schema for a given table ID.
Args:
table_id (str): The ID of the table for which the schema needs
to be generated.
Returns:
tuple: A tuple containing three lists - schema, nested_columns,
and record_columns.
- schema (list): The list of fields in the schema.
- nested_columns (list): A list with the columns of the
nested columns.
- record_columns (list): The columns with the record type.
"""
schema = []
nested_columns = []
record_columns = []
fields = table_id.schema
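        # Flat columns go straight into the schema list; RECORD columns are
        # split into their parent name and their dotted leaf names.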
for field in fields:
record, nested, main_field = self.get_field(field)
if nested:
record_columns.append(main_field)
nested_columns.append(record)
else:
schema.append(record)
return schema, nested_columns, record_columns
def get_field(self, field):
"""Generates a field for the given field object.
Args:
field: The field object for which the field needs to be generated.
Returns:
tuple: A tuple containing three values - record, nested, and
main_cell.
- record: The generated field or list of nested fields.
- nested (bool): Indicates if the field is nested or not.
                - main_cell: The parent field name when the field is nested,
                    otherwise False.
"""
# Checks if the field has nested columns.
if field.field_type == "RECORD":
field_names = []
for subfield in field.fields:
main_cell = field.name
cell = f"{field.name}.{subfield.name}"
# Checks if the field has nested fields within.
if subfield.field_type == "RECORD":
field_names.append(self.get_field(subfield))
else:
field_names.append(cell)
return field_names, True, main_cell
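        # Plain (non-RECORD) fields just return their own name.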
return field.name, False, False
def get_query(
self,
table_bq: bigquery.table.Table,
query_args: Dict) -> str:
"""Creates a SQL query as a string.
Args:
table_bq (bigquery.table.Table): The fully qualified name
of the BigQuery table.
query_args (Dict) :
columns_selected (str): The string with the selected columns.
unnest (str): The unnest string for the query.
limit (int): The maximum number of rows to
retrieve in each block.
offset(int): The starting index of the rows to retrieve.
Returns:
str: SQL query as a string.
"""
query = f"""SELECT {query_args['columns_selected']}
FROM `{table_bq}`,
{query_args['unnest']} LIMIT {query_args['limit']}
OFFSET {query_args['offset']}"""
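        # Illustrative shape of the generated statement (placeholder names):
        #   SELECT col_a, nested.col_b FROM `project.dataset.table`,
        #   UNNEST (nested) LIMIT 100 OFFSET 0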
return query
def get_nested_types(self, table_bq: bigquery.table.Table) -> List[str]:
""" Gets the field modes of the selected table.
Args:
table_bq (bigquery.table.Table): The fully qualified
name of the BigQuery table.
Returns:
List: A complete list with the field modes of the columns.
"""
nested_types = [field.mode for field in table_bq.schema]
return nested_types
def get_rows_query(
self,
nested_args: Dict,
table_bq: bigquery.table.Table,
batch_size: int,
start_index: int
) -> List[Dict]:
"""Retrieves the content of the table.
Args:
nested_args (Dict):
table_schema (List[Dict]): The schema of a BigQuery table.
nested_columns (List[Dict]): A list with the columns
of the nested columns.
record_columns (List[Dict]): The columns with the record type.
table_bq (bigquery.table.Table): The fully qualified name
of the BigQuery table.
batch_size (int): The block of cells to be analyzed.
start_index (int): The starting index of each block to be analyzed.
Returns:
List[Dict]: The content of the BigQuery table.
"""
nested_types = self.get_nested_types(table_bq)
if "REPEATED" in nested_types:
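            # REPEATED fields are arrays, so the record column itself is
            # selected and UNNEST is applied to it directly.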
bq_schema = nested_args["table_schema"] \
+ nested_args["record_columns"]
num_columns = len(bq_schema)
columns_selected = ", ".join(str(column) for column in bq_schema)
unnest = f"UNNEST ({nested_args['record_columns'][0]})"
else:
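            # Non-repeated RECORD fields are wrapped in a single-element
            # array so UNNEST can expose their leaf (dotted) columns.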
bq_schema = nested_args["table_schema"] \
+ nested_args["nested_columns"]
num_columns = len(bq_schema)
columns_selected = ", ".join(str(column) for column in bq_schema)
unnest = (
f"""UNNEST ([{nested_args['record_columns'][0]}])
as {nested_args['record_columns'][0]} """
)
# Generate the SQL query using the selected columns,
# table, unnest, limit, and offset. Calculate the limit and offset for
# the SQL query based on the block size.
sql_query = self.get_query(
table_bq,
{
"columns_selected": columns_selected,
"unnest": unnest,
"offset": int(start_index/num_columns),
"limit": int(batch_size/num_columns)
})
query_job = self.bigquery.bq_client.query(sql_query)
query_results = query_job.result()
bq_rows_content = [tuple(dict(row).values()) for row in query_results]
return bq_rows_content
def get_data_types(self, table_id: str) -> List:
""" Gets the data types of the selected table.
Args:
table_id (str): The fully qualified name of the BigQuery table.
Returns:
List: A complete list with the data types of the columns.
"""
return [field.field_type for field in table_id.schema]
def flatten_list(self, unflattened_list: List) -> List:
"""
Recursively flattens a nested list and returns a flattened list.
Args:
            unflattened_list (list): The input list that needs to be
                flattened.
Returns:
list: The flattened list.
"""
# Create an empty list to store the flattened elements.
flattened = []
# Iterate through each element in the list.
for element in unflattened_list:
# If the element is a list, recursively flatten the list.
if isinstance(element, list):
flattened.extend(self.flatten_list(element))
else:
# If the element is not a list, add it to the flattened list.
flattened.append(element)
# Return the flattened list.
return flattened
def get_bigquery_data(
self,
table_id: str,
start_index: int,
batch_size: int
) -> Tuple[List[Dict], List[Dict]]:
"""Retrieves the schema and content of a BigQuery table.
Args:
table_id (str): The fully qualified name of the BigQuery table.
            start_index (int): The starting index of each block to be analyzed.
batch_size (int): The block of cells to be analyzed.
Returns:
Tuple[List[Dict], List[Dict]]: A tuple containing the BigQuery
schema and content as a List of Dictionaries.
"""
try:
table_bq = self.bigquery.bq_client.get_table(table_id)
except NotFound as exc:
raise ValueError(f"Error retrieving table {table_id}.") from exc
dtypes = self.get_data_types(table_bq)
# Checks if there are nested fields in the schema.
if "RECORD" in dtypes:
table_schema, nested_columns, record_columns = (
self.get_table_schema(table_bq))
table_schema = self.flatten_list(table_schema)
nested_columns = self.flatten_list(nested_columns)
bq_schema = table_schema + nested_columns
bq_rows_content = self.get_rows_query(
{
"table_schema": table_schema,
"nested_columns": nested_columns,
"record_columns": record_columns
},
table_bq,
batch_size,
start_index)
else:
table_schema = table_bq.schema
bq_schema = [field.to_api_repr()["name"] for field in table_schema]
bq_rows_content = self.fetch_rows(
table_bq, start_index, batch_size)
return bq_schema, bq_rows_content
def convert_to_dlp_table(self, schema: List[Dict],
content: List[Dict]) -> dlp_v2.Table:
"""Converts a BigQuery table into a DLP table.
Converts a BigQuery table into a Data Loss Prevention table,
an object that can be inspected by Data Loss Prevention.
Args:
schema (List[Dict]): The schema of a BigQuery table.
content (List[Dict]): The content of a BigQuery table.
Returns:
A table object that can be inspected by Data Loss Prevention.
"""
table_dlp = dlp_v2.Table()
table_dlp.headers = [
{"name": schema_object} for schema_object in schema
]
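        # Every cell value is converted to a string so the DLP table only
        # needs string_value entries.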
rows = []
for row in content:
rows.append(dlp_v2.Table.Row(
values=[dlp_v2.Value(
string_value=str(cell_val)) for cell_val in row]))
table_dlp.rows = rows
return table_dlp
def get_tables_info(self) -> List[Tuple]:
"""Retrieves information about tables in a dataset from
BigQuery or CloudSQL.
        This function retrieves the information needed to run the pipeline
        on Dataflow, since it enables subsequent fragmentation of the tables
        for parallelization.
Returns:
List[Tuple]: A list of tuples containing the table name
and the total number of cells.
"""
# Retrieve table names from either a specific table
# or all tables in a dataset.
if self.source == Database.BIGQUERY:
table_names = [self.bigquery.table] if self.bigquery.table \
else self.get_bigquery_tables(self.bigquery.dataset)
elif self.source == Database.CLOUDSQL:
table_names = [self.cloudsql.table] if self.cloudsql.table \
else self.get_cloudsql_tables()
tables = []
if self.source == Database.BIGQUERY:
for table_name in table_names:
# Get the table object from BigQuery.
table_bq = self.bigquery.bq_client.get_table(
f"{self.bigquery.dataset}.{table_name}")
# Calculate the total number of rows and columns in the table.
num_rows = table_bq.num_rows
dtypes = self.get_data_types(table_bq)
# Checks if there are nested fields in the schema.
if "RECORD" in dtypes:
table_schema, _, record_columns = (
self.get_table_schema(table_bq))
num_columns = len(table_schema + record_columns)
else:
num_columns = len(table_bq.schema)
# Append a tuple with the table name and the total number
# of cells to the list.
                tables.append((table_name, num_rows * num_columns))
elif self.source == Database.CLOUDSQL:
# Create a database engine instance.
engine = create_engine(
f'{self.cloudsql.connection_type}://',
creator=self.get_connection
)
for table_name in table_names:
# Create a Metadata and Table instance.
metadata = MetaData()
table = Table(table_name, metadata, extend_existing=True,
autoload_with=engine)
num_columns = len(table.columns.keys())
# Get table contents.
with engine.connect() as connection:
count_query = select(
# pylint: disable=E1102
func.count("*")).select_from(table)
num_rows = connection.execute(count_query).scalar()
                tables.append((table_name, num_rows * num_columns))
return tables
def get_dlp_table_per_block(
self,
batch_size: int,
table_name: str,
start_index: int
    ) -> dlp_v2.Table:
        """Constructs a DLP Table object from each block of cells.
Args:
batch_size (int): The block of cells to be analyzed.
table_name (str): The name of the table to be analyzed.
start_index (int): The starting index of each block to be analyzed.
Returns:
A Data Loss Prevention table object.
"""
if self.source == Database.BIGQUERY:
schema, content = self.get_bigquery_data(
f"{self.bigquery.dataset}.{table_name}",
start_index,
batch_size
)
elif self.source == Database.CLOUDSQL:
schema, content = self.get_cloudsql_data(
table_name,
batch_size,
start_index
)
return self.convert_to_dlp_table(schema, content)
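

# Illustrative usage (a minimal sketch, not part of the pipeline): the
# project, zone, and dataset names below are placeholders.
#
#   preprocess = Preprocessing(
#       source="bigquery",
#       project="my-project",
#       zone="us-central1",
#       bigquery_args={"dataset": "my_dataset", "table": None},
#   )
#   for table_name, num_cells in preprocess.get_tables_info():
#       dlp_table = preprocess.get_dlp_table_per_block(
#           batch_size=50000, table_name=table_name, start_index=0)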