generator/dryrun.py (247 lines of code) (raw):

"""Dry Run method to get BigQuery metadata.""" import json from enum import Enum from functools import cached_property from typing import Optional from urllib.request import Request, urlopen import google.auth from google.auth.transport.requests import Request as GoogleAuthRequest from google.cloud import bigquery from google.oauth2.id_token import fetch_id_token DRY_RUN_URL = ( "https://us-central1-moz-fx-data-shared-prod.cloudfunctions.net/bigquery-etl-dryrun" ) def credentials(auth_req: Optional[GoogleAuthRequest] = None): """Get GCP credentials.""" auth_req = auth_req or GoogleAuthRequest() creds, _ = google.auth.default( scopes=["https://www.googleapis.com/auth/cloud-platform"] ) creds.refresh(auth_req) return creds def id_token(): """Get token to authenticate against Cloud Function.""" auth_req = GoogleAuthRequest() creds = credentials(auth_req) if hasattr(creds, "id_token"): # Get token from default credentials for the current environment created via Cloud SDK run id_token = creds.id_token else: # If the environment variable GOOGLE_APPLICATION_CREDENTIALS is set to service account JSON file, # then ID token is acquired using this service account credentials. id_token = fetch_id_token(auth_req, DRY_RUN_URL) return id_token class DryRunError(Exception): """Exception raised on dry run errors.""" def __init__(self, message, error, use_cloud_function, table_id): """Initialize DryRunError.""" super().__init__(message) self.error = error self.use_cloud_function = use_cloud_function self.table_id = table_id def __reduce__(self): """ Override to ensure that all parameters are being passed when pickling. Pickling happens when passing exception between processes (e.g. via multiprocessing) """ return ( self.__class__, self.args + (self.error, self.use_cloud_function, self.table_id), ) class Errors(Enum): """DryRun errors that require special handling.""" READ_ONLY = 1 DATE_FILTER_NEEDED = 2 DATE_FILTER_NEEDED_AND_SYNTAX = 3 PERMISSION_DENIED = 4 class DryRunContext: """DryRun builder class.""" def __init__( self, use_cloud_function=False, id_token=None, credentials=None, dry_run_url=DRY_RUN_URL, ): """Initialize dry run instance.""" self.use_cloud_function = use_cloud_function self.dry_run_url = dry_run_url self.id_token = id_token self.credentials = credentials def create( self, sql=None, project="moz-fx-data-shared-prod", dataset=None, table=None, ): """Initialize a DryRun instance.""" return DryRun( use_cloud_function=self.use_cloud_function, id_token=self.id_token, credentials=self.credentials, sql=sql, project=project, dataset=dataset, table=table, dry_run_url=self.dry_run_url, ) class DryRun: """Dry run SQL.""" def __init__( self, use_cloud_function=False, id_token=None, credentials=None, sql=None, project="moz-fx-data-shared-prod", dataset=None, table=None, dry_run_url=DRY_RUN_URL, ): """Initialize dry run instance.""" self.sql = sql self.use_cloud_function = use_cloud_function self.project = project self.dataset = dataset self.table = table self.dry_run_url = dry_run_url self.id_token = id_token self.credentials = credentials @cached_property def client(self): """Get BigQuery client instance.""" return bigquery.Client(credentials=self.credentials) @cached_property def dry_run_result(self): """Return the dry run result.""" try: if self.use_cloud_function: json_data = { "query": self.sql or "SELECT 1", "project": self.project, "dataset": self.dataset or "telemetry", } if self.table: json_data["table"] = self.table r = urlopen( Request( self.dry_run_url, headers={ "Content-Type": "application/json", "Authorization": f"Bearer {self.id_token}", }, data=json.dumps(json_data).encode("utf8"), method="POST", ) ) return json.load(r) else: query_schema = None referenced_tables = [] table_metadata = None if self.sql: job_config = bigquery.QueryJobConfig( dry_run=True, use_query_cache=False, query_parameters=[ bigquery.ScalarQueryParameter( "submission_date", "DATE", "2019-01-01" ) ], ) if self.project: job_config.connection_properties = [ bigquery.ConnectionProperty( "dataset_project_id", self.project ) ] job = self.client.query(self.sql, job_config=job_config) query_schema = ( job._properties.get("statistics", {}) .get("query", {}) .get("schema", {}) ) referenced_tables = [ ref.to_api_repr() for ref in job.referenced_tables ] if ( self.project is not None and self.table is not None and self.dataset is not None ): table = self.client.get_table( f"{self.project}.{self.dataset}.{self.table}" ) table_metadata = { "tableType": table.table_type, "friendlyName": table.friendly_name, "schema": { "fields": [field.to_api_repr() for field in table.schema] }, } return { "valid": True, "referencedTables": referenced_tables, "schema": query_schema, "tableMetadata": table_metadata, } except Exception as e: print(f"ERROR {e}") return None def get_schema(self): """Return the query schema by dry running the SQL file.""" self.validate() if ( self.dry_run_result and self.dry_run_result["valid"] and "schema" in self.dry_run_result ): return self.dry_run_result["schema"]["fields"] return [] def get_table_schema(self): """Return the schema of the provided table.""" self.validate() if ( self.dry_run_result and self.dry_run_result["valid"] and "tableMetadata" in self.dry_run_result ): return self.dry_run_result["tableMetadata"]["schema"]["fields"] return [] def get_table_metadata(self): """Return table metadata.""" self.validate() if ( self.dry_run_result and self.dry_run_result["valid"] and "tableMetadata" in self.dry_run_result ): return self.dry_run_result["tableMetadata"] return {} def validate(self): """Dry run the provided SQL file and check if valid.""" dry_run_error = DryRunError( "Error when dry running SQL", self.get_error(), self.use_cloud_function, self.table, ) if self.dry_run_result is None: raise dry_run_error if self.dry_run_result["valid"]: return True elif self.get_error() == Errors.READ_ONLY: # We want the dryrun service to only have read permissions, so # we expect CREATE VIEW and CREATE TABLE to throw specific # exceptions. return True elif self.get_error() == Errors.DATE_FILTER_NEEDED: # With strip_dml flag, some queries require a partition filter # (submission_date, submission_timestamp, etc.) to run return True else: print("ERROR\n", self.dry_run_result["errors"]) raise dry_run_error def errors(self): """Dry run the provided SQL file and return errors.""" if self.dry_run_result is None: return [] return self.dry_run_result.get("errors", []) def get_error(self) -> Optional[Errors]: """Get specific errors for edge case handling.""" errors = self.errors() if len(errors) != 1: return None error = errors[0] if error and error.get("code") in [400, 403]: error_message = error.get("message", "") if ( "does not have bigquery.tables.create permission for dataset" in error_message or "Permission bigquery.tables.create denied" in error_message or "Permission bigquery.datasets.update denied" in error_message ): return Errors.READ_ONLY if "without a filter over column(s)" in error_message: return Errors.DATE_FILTER_NEEDED if ( "Syntax error: Expected end of input but got keyword WHERE" in error_message ): return Errors.DATE_FILTER_NEEDED_AND_SYNTAX if ( "Permission bigquery.tables.get denied on table" in error_message or "User does not have permission to query table" in error_message ): return Errors.PERMISSION_DENIED return None