generator/dryrun.py

"""Dry Run method to get BigQuery metadata."""
import json
from enum import Enum
from functools import cached_property
from typing import Optional
from urllib.request import Request, urlopen

import google.auth
from google.auth.transport.requests import Request as GoogleAuthRequest
from google.cloud import bigquery
from google.oauth2.id_token import fetch_id_token

DRY_RUN_URL = (
    "https://us-central1-moz-fx-data-shared-prod.cloudfunctions.net/bigquery-etl-dryrun"
)


def credentials(auth_req: Optional[GoogleAuthRequest] = None):
"""Get GCP credentials."""
auth_req = auth_req or GoogleAuthRequest()
creds, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
creds.refresh(auth_req)
return creds


def id_token():
"""Get token to authenticate against Cloud Function."""
auth_req = GoogleAuthRequest()
creds = credentials(auth_req)
if hasattr(creds, "id_token"):
        # Default credentials created via the Cloud SDK (e.g. `gcloud auth`)
        # already carry an ID token for the current environment.
        id_token = creds.id_token
else:
        # If GOOGLE_APPLICATION_CREDENTIALS points to a service account JSON
        # file, the ID token is fetched using those service account credentials.
        id_token = fetch_id_token(auth_req, DRY_RUN_URL)
return id_token
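
# Illustrative auth sketch (not part of the module's flow): both helpers rely
# on application default credentials being available, e.g. after running
# `gcloud auth application-default login`.
#
#     token = id_token()     # for the Cloud Function path
#     creds = credentials()  # for direct BigQuery access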


class DryRunError(Exception):
    """Exception raised on dry run errors."""

    def __init__(self, message, error, use_cloud_function, table_id):
"""Initialize DryRunError."""
super().__init__(message)
self.error = error
self.use_cloud_function = use_cloud_function
self.table_id = table_id

    def __reduce__(self):
        """
        Override to ensure all parameters are passed along when pickling.

        Pickling happens when an exception is passed between processes
        (e.g. via multiprocessing).
        """
return (
self.__class__,
self.args + (self.error, self.use_cloud_function, self.table_id),
)
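
# Pickle round-trip sketch (illustrative): __reduce__ above keeps the extra
# attributes intact when the exception crosses process boundaries, e.g. under
# multiprocessing.
#
#     import pickle
#     err = DryRunError("oops", Errors.READ_ONLY, True, "some_table")
#     assert pickle.loads(pickle.dumps(err)).table_id == "some_table"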


class Errors(Enum):
    """DryRun errors that require special handling."""

    READ_ONLY = 1
DATE_FILTER_NEEDED = 2
DATE_FILTER_NEEDED_AND_SYNTAX = 3
PERMISSION_DENIED = 4


class DryRunContext:
    """DryRun builder class."""

    def __init__(
self,
use_cloud_function=False,
id_token=None,
credentials=None,
dry_run_url=DRY_RUN_URL,
):
"""Initialize dry run instance."""
self.use_cloud_function = use_cloud_function
self.dry_run_url = dry_run_url
self.id_token = id_token
self.credentials = credentials

    def create(
self,
sql=None,
project="moz-fx-data-shared-prod",
dataset=None,
table=None,
):
"""Initialize a DryRun instance."""
return DryRun(
use_cloud_function=self.use_cloud_function,
id_token=self.id_token,
credentials=self.credentials,
sql=sql,
project=project,
dataset=dataset,
table=table,
dry_run_url=self.dry_run_url,
)
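
# Illustrative usage of the builder (the SQL and dataset are examples, not
# values defined in this module):
#
#     ctx = DryRunContext(use_cloud_function=True, id_token=id_token())
#     dry_run = ctx.create(sql="SELECT 1", dataset="telemetry")
#     schema = dry_run.get_schema()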


class DryRun:
    """Dry run SQL."""

    def __init__(
self,
use_cloud_function=False,
id_token=None,
credentials=None,
sql=None,
project="moz-fx-data-shared-prod",
dataset=None,
table=None,
dry_run_url=DRY_RUN_URL,
):
"""Initialize dry run instance."""
self.sql = sql
self.use_cloud_function = use_cloud_function
self.project = project
self.dataset = dataset
self.table = table
self.dry_run_url = dry_run_url
self.id_token = id_token
self.credentials = credentials

    @cached_property
def client(self):
"""Get BigQuery client instance."""
return bigquery.Client(credentials=self.credentials)

    @cached_property
def dry_run_result(self):
"""Return the dry run result."""
try:
            if self.use_cloud_function:
                # POST the query to the dry run Cloud Function and return
                # its JSON response unchanged.
                json_data = {
"query": self.sql or "SELECT 1",
"project": self.project,
"dataset": self.dataset or "telemetry",
}
if self.table:
json_data["table"] = self.table
r = urlopen(
Request(
self.dry_run_url,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {self.id_token}",
},
data=json.dumps(json_data).encode("utf8"),
method="POST",
)
)
return json.load(r)
            else:
                # Dry run locally against BigQuery and assemble a response
                # mirroring the Cloud Function's shape.
                query_schema = None
                referenced_tables = []
                table_metadata = None
if self.sql:
job_config = bigquery.QueryJobConfig(
dry_run=True,
use_query_cache=False,
query_parameters=[
bigquery.ScalarQueryParameter(
"submission_date", "DATE", "2019-01-01"
)
],
)
if self.project:
job_config.connection_properties = [
bigquery.ConnectionProperty(
"dataset_project_id", self.project
)
]
job = self.client.query(self.sql, job_config=job_config)
query_schema = (
job._properties.get("statistics", {})
.get("query", {})
.get("schema", {})
)
referenced_tables = [
ref.to_api_repr() for ref in job.referenced_tables
]
if (
self.project is not None
and self.table is not None
and self.dataset is not None
):
table = self.client.get_table(
f"{self.project}.{self.dataset}.{self.table}"
)
table_metadata = {
"tableType": table.table_type,
"friendlyName": table.friendly_name,
"schema": {
"fields": [field.to_api_repr() for field in table.schema]
},
}
return {
"valid": True,
"referencedTables": referenced_tables,
"schema": query_schema,
"tableMetadata": table_metadata,
}
        except Exception as e:
            # Swallow the exception: a failed dry run surfaces as a None
            # result, which validate() converts into a DryRunError.
            print(f"ERROR {e}")
            return None
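
    # On success, both branches above yield a dict of this shape (the Cloud
    # Function response may additionally carry an "errors" list on failure):
    #
    #     {
    #         "valid": True,
    #         "referencedTables": [...],   # table references as API repr dicts
    #         "schema": {"fields": [...]},
    #         "tableMetadata": {...},      # only when project/dataset/table set
    #     }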

    def get_schema(self):
        """Return the query schema by dry running the SQL file."""
        self.validate()

        # Check the value, not just key presence: "schema" may be present
        # but None when no SQL was dry run.
        if (
            self.dry_run_result
            and self.dry_run_result["valid"]
            and self.dry_run_result.get("schema")
        ):
            return self.dry_run_result["schema"]["fields"]
        return []

    def get_table_schema(self):
        """Return the schema of the provided table."""
        self.validate()

        # "tableMetadata" may be present but None when no table was given.
        if (
            self.dry_run_result
            and self.dry_run_result["valid"]
            and self.dry_run_result.get("tableMetadata")
        ):
            return self.dry_run_result["tableMetadata"]["schema"]["fields"]
        return []

    def get_table_metadata(self):
        """Return table metadata."""
        self.validate()

        if (
            self.dry_run_result
            and self.dry_run_result["valid"]
            and self.dry_run_result.get("tableMetadata")
        ):
            return self.dry_run_result["tableMetadata"]
        return {}

    def validate(self):
"""Dry run the provided SQL file and check if valid."""
dry_run_error = DryRunError(
"Error when dry running SQL",
self.get_error(),
self.use_cloud_function,
self.table,
)
if self.dry_run_result is None:
raise dry_run_error
if self.dry_run_result["valid"]:
return True
elif self.get_error() == Errors.READ_ONLY:
# We want the dryrun service to only have read permissions, so
# we expect CREATE VIEW and CREATE TABLE to throw specific
# exceptions.
return True
        elif self.get_error() == Errors.DATE_FILTER_NEEDED:
            # With the strip_dml flag, some queries require a partition filter
            # (submission_date, submission_timestamp, etc.) to run.
            return True
        else:
            print("ERROR\n", self.dry_run_result.get("errors"))
            raise dry_run_error

    def errors(self):
"""Dry run the provided SQL file and return errors."""
if self.dry_run_result is None:
return []
return self.dry_run_result.get("errors", [])

    def get_error(self) -> Optional[Errors]:
"""Get specific errors for edge case handling."""
errors = self.errors()
if len(errors) != 1:
return None
error = errors[0]
if error and error.get("code") in [400, 403]:
error_message = error.get("message", "")
if (
"does not have bigquery.tables.create permission for dataset"
in error_message
or "Permission bigquery.tables.create denied" in error_message
or "Permission bigquery.datasets.update denied" in error_message
):
return Errors.READ_ONLY
if "without a filter over column(s)" in error_message:
return Errors.DATE_FILTER_NEEDED
if (
"Syntax error: Expected end of input but got keyword WHERE"
in error_message
):
return Errors.DATE_FILTER_NEEDED_AND_SYNTAX
if (
"Permission bigquery.tables.get denied on table" in error_message
or "User does not have permission to query table" in error_message
):
return Errors.PERMISSION_DENIED
return None
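

# Minimal smoke-test sketch (assumes application default credentials with
# BigQuery dry-run access; the query is illustrative):
if __name__ == "__main__":
    print(DryRunContext().create(sql="SELECT 1").get_schema())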