# common/py_libs/bq_helper.py
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Library for BigQuery related functions."""
from collections import abc
from enum import Enum
import logging
import pathlib
import typing

from google.api_core import retry
from google.cloud import bigquery
from google.cloud.exceptions import BadRequest
from google.cloud.exceptions import Conflict
from google.cloud.exceptions import Forbidden
from google.cloud.exceptions import NotFound

from common.py_libs import constants
from common.py_libs import cortex_exceptions as exc
from common.py_libs.bq_materializer import add_cluster_to_table_def
from common.py_libs.bq_materializer import add_partition_to_table_def

logger = logging.getLogger(__name__)


def execute_sql_file(bq_client: bigquery.Client,
sql_file: typing.Union[str, pathlib.Path]) -> None:
"""Executes a Bigquery sql file."""
# TODO: Convert sql_file from str to Path.
with open(sql_file, mode="r", encoding="utf-8") as sqlf:
sql_str = sqlf.read()
logging.debug("Executing SQL: %s", sql_str)
try:
query_job = bq_client.query(sql_str)
        # Wait for the query to complete.
_ = query_job.result()
except BadRequest:
logging.error("Error when executing SQL:\n%s", sql_str)
raise
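
# Illustrative usage sketch (not part of the library): the project ID and SQL
# file path below are placeholders.
#
#   client = bigquery.Client(project="my-project")
#   execute_sql_file(client, "sql/create_reporting_views.sql")

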
def table_exists(bq_client: bigquery.Client, full_table_name: str) -> bool:
"""Checks if a given table exists in BigQuery."""
try:
bq_client.get_table(full_table_name)
return True
except NotFound:
return False


def dataset_exists(bq_client: bigquery.Client, full_dataset_name: str) -> bool:
"""Checks if a given dataset exists in BigQuery."""
try:
bq_client.get_dataset(full_dataset_name)
return True
except NotFound:
return False
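
# Illustrative existence checks with placeholder names, e.g. to decide whether
# a load or a CREATE statement is needed at all:
#
#   if not dataset_exists(client, "my-project.my_dataset"):
#       ...  # create the dataset first
#   if not table_exists(client, "my-project.my_dataset.orders"):
#       ...  # create or load the table

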
DatasetExistence = Enum("DatasetExistence",
["NOT_EXISTS",
"EXISTS_IN_LOCATION",
"EXISTS_IN_ANOTHER_LOCATION"])


def dataset_exists_in_location(bq_client: bigquery.Client,
full_dataset_name: str,
location: str) -> DatasetExistence:
"""Checks if a given dataset exists in BigQuery in a location."""
try:
dataset = bq_client.get_dataset(full_dataset_name)
return (DatasetExistence.EXISTS_IN_LOCATION
if dataset.location.lower() == location.lower() # type: ignore
else DatasetExistence.EXISTS_IN_ANOTHER_LOCATION)
except NotFound:
return DatasetExistence.NOT_EXISTS
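
# Illustrative location check (placeholder names), useful when a dataset is
# required to live in a specific BigQuery location:
#
#   state = dataset_exists_in_location(client, "my-project.my_dataset", "US")
#   if state == DatasetExistence.EXISTS_IN_ANOTHER_LOCATION:
#       raise ValueError("Dataset exists, but not in the expected location.")
#   elif state == DatasetExistence.NOT_EXISTS:
#       create_dataset(client, "my-project.my_dataset", "US")

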
def _wait_for_bq_jobs(jobs: typing.List[typing.Union[bigquery.CopyJob,
bigquery.LoadJob]],
continue_if_failed: bool):
"""Waits for BigQuery jobs to finish."""
# We can simply wait for them in a for loop
# because we need all of them to finish.
for job in jobs:
try:
job.result(retry=retry.Retry(deadline=60))
logging.info("✅ Table %s has been loaded.", job.destination)
except Conflict:
logging.warning("⚠️ Table %s already exists. Skipping it.",
job.destination)
        except Exception:  # pylint: disable=broad-exception-caught
logging.error(
"⛔️ Failed to load table %s.\n",
job.destination,
exc_info=True)
if not continue_if_failed:
raise


def load_tables(bq_client: bigquery.Client,
sources: typing.Union[str, typing.List[str]],
target_tables: typing.Union[str, typing.List[str]],
location: str,
continue_if_failed: bool = False,
skip_existing_tables: bool = False,
write_disposition: str = bigquery.WriteDisposition.WRITE_EMPTY,
parallel_jobs: int = 5):
"""Loads data to multiple BigQuery tables.
Args:
bq_client (bigquery.Client): BigQuery client to use.
sources (str | list[str]): data source URI or name.
Supported sources:
- BigQuery table name as project.dataset.table
- Any URI supported by load_table_from_uri
for avro, csv, json and parquet files.
        target_tables (str | list[str]): full target table names as
"project.dataset.table".
location (str): BigQuery location.
continue_if_failed (bool): continue loading tables if some jobs fail.
skip_existing_tables (bool): Skip tables that already exist.
Defaults to False.
        write_disposition (bigquery.WriteDisposition): write disposition.
Defaults to WRITE_EMPTY (skip if has data).
parallel_jobs (int): maximum number of parallel jobs. Defaults to 5.
Raises:
ValueError: If the number of source URIs is not equal to the number
of target tables, or if an unsupported source format
is provided.
"""
    # Note: `str` is itself a Sequence, so it must be special-cased here;
    # otherwise a single table name would be treated as a list of characters.
    if isinstance(sources, str) or not isinstance(sources, abc.Sequence):
        sources = [sources]
    if isinstance(target_tables, str) or not isinstance(
            target_tables, abc.Sequence):
        target_tables = [target_tables]
if len(target_tables) != len(sources):
raise ValueError(("Number of source URIs must be equal to "
"number of target tables."))
jobs = []
for index, source in enumerate(sources):
target = target_tables[index]
logging.info("Loading table %s from %s.", target, source)
if skip_existing_tables and table_exists(bq_client, target):
logging.warning("⚠️ Table %s already exists. Skipping it.", target)
continue
if "://" in source:
ext = source.split(".")[-1].lower()
if ext == "avro":
source_format = bigquery.SourceFormat.AVRO
elif ext == "parquet":
source_format = bigquery.SourceFormat.PARQUET
elif ext == "csv":
source_format = bigquery.SourceFormat.CSV
elif ext == "json":
source_format = bigquery.SourceFormat.NEWLINE_DELIMITED_JSON
else:
raise ValueError((f"Extension `{ext}` "
"is an unsupported source format."))
job_config = bigquery.LoadJobConfig(
autodetect=True,
source_format=source_format,
write_disposition=write_disposition,
)
load_job = bq_client.load_table_from_uri(
source_uris=source,
destination=target,
job_config=job_config,
location=location,
retry=retry.Retry(deadline=60),
)
else:
job_config = bigquery.CopyJobConfig(
write_disposition=write_disposition)
load_job = bq_client.copy_table(source,
target,
location=location,
job_config=job_config,
retry=retry.Retry(deadline=60))
jobs.append(load_job)
        # Wait for pending jobs once their count reaches parallel_jobs.
if len(jobs) >= parallel_jobs:
_wait_for_bq_jobs(jobs, continue_if_failed)
jobs.clear()
    # Wait for the remaining jobs to finish.
_wait_for_bq_jobs(jobs, continue_if_failed)
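
# Illustrative call with placeholder URIs and table names: one Parquet file is
# loaded from GCS and one existing table is copied, both overwriting any
# existing data:
#
#   load_tables(
#       client,
#       sources=["gs://my-bucket/data/orders.parquet",
#                "my-project.staging.customers"],
#       target_tables=["my-project.reporting.orders",
#                      "my-project.reporting.customers"],
#       location="US",
#       write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE)

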
def create_dataset(bq_client: bigquery.Client,
dataset_name: str,
location: str,
suppress_success_logging: bool = False) -> None:
"""Creates a BigQuery dataset."""
dataset_ref = bigquery.Dataset(dataset_name)
dataset_ref.location = location
try:
bq_client.create_dataset(dataset_ref, timeout=30)
if not suppress_success_logging:
logging.info("✅ Dataset %s has been created in %s.",
dataset_name,
location)
except Conflict:
logging.warning("⚠️ Dataset %s already exists in %s. Skipping it.",
dataset_name,
location)
except Exception:
logging.error("⛔️ Failed to create dataset %s in %s.", dataset_name,
location,
exc_info=True)
raise
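
# Illustrative call (placeholder names); safe to repeat, since an existing
# dataset only results in a warning:
#
#   create_dataset(client, "my-project.reporting", "US")

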
def create_table(bq_client: bigquery.Client,
full_table_name: str,
schema_tuples_list: list[tuple[str, str]],
exists_ok=False) -> None:
"""Creates a BigQuery table based on given schema."""
create_table_from_schema(
bq_client,
full_table_name,
[bigquery.SchemaField(t[0], t[1]) for t in schema_tuples_list],
None,
None,
exists_ok
)
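
# Illustrative call with a placeholder table name and a schema expressed as
# (name, type) tuples:
#
#   create_table(
#       client,
#       "my-project.reporting.currency_conversion",
#       [("from_currency", "STRING"),
#        ("to_currency", "STRING"),
#        ("rate", "NUMERIC")],
#       exists_ok=True)

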
def create_table_from_schema(bq_client: bigquery.Client,
full_table_name: str,
schema: typing.List[bigquery.SchemaField],
partition_details=None,
cluster_details=None,
exists_ok=False) -> None:
"""Creates a table in BigQuery. Skips creation if it exists.
Args:
bq_client (bigquery.Client): BQ client.
full_table_name (str): Full table name (project.dataset.table).
schema (list[bigquery.SchemaField]): BQ schema.
partition_details (dict): Partition details.
cluster_details (dict): Clustering details.
exists_ok (Optional[bool]):
Allow the table to exist prior to this call. Defaults to False.
"""
logging.info("Creating table %s", full_table_name)
project, dataset_id, table_id = full_table_name.split(".")
table_ref = bigquery.TableReference(
bigquery.DatasetReference(project, dataset_id), table_id)
table = bigquery.Table(
table_ref,
schema=schema)
if partition_details:
table = add_partition_to_table_def(table, partition_details)
if cluster_details:
table = add_cluster_to_table_def(table, cluster_details)
bq_client.create_table(table, exists_ok=exists_ok)
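
# Illustrative call with placeholder names. The expected shapes of
# `partition_details` and `cluster_details` are defined by bq_materializer and
# are not repeated here:
#
#   schema = [
#       bigquery.SchemaField("id", "STRING", mode="REQUIRED"),
#       bigquery.SchemaField("created_at", "TIMESTAMP"),
#   ]
#   create_table_from_schema(client, "my-project.reporting.events", schema,
#                            exists_ok=True)

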
def get_table_list(bq_client: bigquery.Client,
project_id: str,
dataset_name: str) -> typing.List[str]:
    """Returns the table IDs of all tables in the given dataset."""
    ds_ref = bigquery.DatasetReference(project_id, dataset_name)
return [t.table_id for t in bq_client.list_tables(ds_ref)]


def delete_table(bq_client: bigquery.Client, full_table_name: str) -> None:
    """Deletes a BigQuery table. Does nothing if the table does not exist."""
logger.info("Deleting table `%s`.", full_table_name)
bq_client.delete_table(full_table_name, not_found_ok=True)
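
# Illustrative cleanup loop over placeholder names, combining get_table_list
# and delete_table:
#
#   for table_id in get_table_list(client, "my-project", "staging"):
#       delete_table(client, f"my-project.staging.{table_id}")

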
def copy_dataset(bq_client: bigquery.Client,
source_project: str,
source_dataset: str,
target_project: str,
target_dataset: str,
location: str,
skip_existing_tables: bool = False,
write_disposition: str = (
bigquery.WriteDisposition.WRITE_EMPTY)):
"""Copies all tables from source dataset to target.
Args:
bq_client (bigquery.Client): BigQuery client to use.
source_project (str): Source project.
source_dataset (str): Source dataset. Must exist in specified location.
target_project (str): Target project.
target_dataset (str): Target dataset. Must exist in specified location.
location (str): BigQuery location.
skip_existing_tables (bool): Skip tables that already exist.
Defaults to False.
        write_disposition (bigquery.WriteDisposition): Write disposition.
Defaults to WRITE_EMPTY (skip if has data).
"""
logging.info("Copying tables from `%s.%s` to `%s.%s`.",
source_project, source_dataset,
target_project, target_dataset)
tables = get_table_list(bq_client, source_project, source_dataset)
source_tables = [f"{source_project}.{source_dataset}.{t}" for t in tables]
target_tables = [f"{target_project}.{target_dataset}.{t}" for t in tables]
load_tables(bq_client,
source_tables, target_tables, location,
skip_existing_tables=skip_existing_tables,
write_disposition=write_disposition)
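
# Illustrative call with placeholder names; both datasets must already exist
# in the given location:
#
#   copy_dataset(client,
#                source_project="my-project",
#                source_dataset="template_ds",
#                target_project="my-project",
#                target_dataset="reporting",
#                location="US",
#                skip_existing_tables=True)

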
def label_dataset(bq_client: bigquery.Client,
dataset: bigquery.Dataset) -> None:
"""
    Adds the Cortex label to a BigQuery dataset.
    Only updates the dataset if the label is not already present.
Args:
bq_client (bigquery.Client): BigQuery client to use.
dataset (bigquery.Dataset): Dataset to label.
Raises:
NotFoundCError: If dataset does not exist.
"""
try:
labels = dataset.labels or {}
update_labels = False
for key, value in constants.BQ_DATASET_LABEL.items():
if key not in labels:
labels[key] = value
update_labels = True
if update_labels:
dataset.labels = labels
bq_client.update_dataset(dataset, ["labels"])
    except NotFound as e:
        error_msg = (
            f"Dataset {dataset.project}.{dataset.dataset_id} not found.")
        raise exc.NotFoundCError(error_msg) from e
except Forbidden:
logger.info(
"Permission to tag %s.%s denied. Skipping tagging dataset.",
dataset.project,
dataset.dataset_id)
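
# Illustrative call with placeholder names: fetch the dataset first, then let
# label_dataset add the Cortex label if it is missing:
#
#   dataset = client.get_dataset("my-project.reporting")
#   label_dataset(client, dataset)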