# bigquery_etl/dryrun.py

""" Dry run query files. Passes all queries to a Cloud Function that will run the queries with the dry_run option enabled. We could provision BigQuery credentials to the CircleCI job to allow it to run the queries directly, but there is no way to restrict permissions such that only dry runs can be performed. In order to reduce risk of CI or local users accidentally running queries during tests and overwriting production data, we proxy the queries through the dry run service endpoint. """ import glob import json import re import sys import time from enum import Enum from os.path import basename, dirname, exists from pathlib import Path from typing import Optional, Set from urllib.request import Request, urlopen import click import google.auth from google.auth.transport.requests import Request as GoogleAuthRequest from google.cloud import bigquery from google.oauth2.id_token import fetch_id_token from .config import ConfigLoader from .metadata.parse_metadata import Metadata from .util.common import render try: from functools import cached_property # type: ignore except ImportError: # python 3.7 compatibility from backports.cached_property import cached_property # type: ignore QUERY_PARAMETER_TYPE_VALUES = { "DATE": "2019-01-01", "DATETIME": "2019-01-01 00:00:00", "TIMESTAMP": "2019-01-01 00:00:00", "STRING": "foo", "BOOL": True, "FLOAT64": 1, "FLOAT": 1, "INT64": 1, "INTEGER": 1, "NUMERIC": 1, "BIGNUMERIC": 1, } def get_credentials(auth_req: Optional[GoogleAuthRequest] = None): """Get GCP credentials.""" auth_req = auth_req or GoogleAuthRequest() credentials, _ = google.auth.default( scopes=["https://www.googleapis.com/auth/cloud-platform"] ) credentials.refresh(auth_req) return credentials def get_id_token(dry_run_url=ConfigLoader.get("dry_run", "function"), credentials=None): """Get token to authenticate against Cloud Function.""" auth_req = GoogleAuthRequest() credentials = credentials or get_credentials(auth_req) if hasattr(credentials, "id_token"): # Get token from default credentials for the current environment created via Cloud SDK run id_token = credentials.id_token else: # If the environment variable GOOGLE_APPLICATION_CREDENTIALS is set to service account JSON file, # then ID token is acquired using this service account credentials. id_token = fetch_id_token(auth_req, dry_run_url) return id_token class Errors(Enum): """DryRun errors that require special handling.""" READ_ONLY = 1 DATE_FILTER_NEEDED = 2 DATE_FILTER_NEEDED_AND_SYNTAX = 3 class DryRun: """Dry run SQL files.""" def __init__( self, sqlfile, content=None, strip_dml=False, use_cloud_function=True, client=None, respect_skip=True, sql_dir=ConfigLoader.get("default", "sql_dir"), id_token=None, credentials=None, project=None, dataset=None, table=None, ): """Instantiate DryRun class.""" self.sqlfile = sqlfile self.content = content self.strip_dml = strip_dml self.use_cloud_function = use_cloud_function self.bq_client = client self.respect_skip = respect_skip self.dry_run_url = ConfigLoader.get("dry_run", "function") self.sql_dir = sql_dir self.id_token = ( id_token if not use_cloud_function or id_token else get_id_token(self.dry_run_url) ) self.credentials = credentials self.project = project self.dataset = dataset self.table = table try: self.metadata = Metadata.of_query_file(self.sqlfile) except FileNotFoundError: self.metadata = None self.dry_run_duration = None from bigquery_etl.cli.utils import is_authenticated if not is_authenticated(): print( "Authentication to GCP required. 
Run `gcloud auth login --update-adc` " "and check that the project is set correctly." ) sys.exit(1) @cached_property def client(self): """Get BigQuery client instance.""" if self.use_cloud_function: return None return self.bq_client or bigquery.Client(credentials=self.credentials) @staticmethod def skipped_files(sql_dir=ConfigLoader.get("default", "sql_dir")) -> Set[str]: """Return files skipped by dry run.""" default_sql_dir = Path(ConfigLoader.get("default", "sql_dir")) sql_dir = Path(sql_dir) file_pattern_re = re.compile(rf"^{re.escape(str(default_sql_dir))}/") skip_files = { file for skip in ConfigLoader.get("dry_run", "skip", fallback=[]) for file in glob.glob( file_pattern_re.sub(f"{str(sql_dir)}/", skip), recursive=True, ) } # update skip list to include renamed queries in stage. test_project = ConfigLoader.get("default", "test_project", fallback="") file_pattern_re = re.compile(r"sql/([^\/]+)/([^/]+)(/?.*|$)") skip_files.update( [ file for skip in ConfigLoader.get("dry_run", "skip", fallback=[]) for file in glob.glob( file_pattern_re.sub( lambda x: f"sql/{test_project}/{x.group(2)}_{x.group(1).replace('-', '_')}*{x.group(3)}", skip, ), recursive=True, ) ] ) return skip_files def skip(self): """Determine if dry run should be skipped.""" return self.respect_skip and self.sqlfile in self.skipped_files( sql_dir=self.sql_dir ) def get_sql(self): """Get SQL content.""" if exists(self.sqlfile): file_path = Path(self.sqlfile) sql = render( file_path.name, format=False, template_folder=file_path.parent.absolute(), ) else: raise ValueError(f"Invalid file path: {self.sqlfile}") if self.strip_dml: sql = re.sub( "CREATE OR REPLACE VIEW.*?AS", "", sql, flags=re.DOTALL, ) sql = re.sub( "CREATE MATERIALIZED VIEW.*?AS", "", sql, flags=re.DOTALL, ) return sql @cached_property def dry_run_result(self): """Dry run the provided SQL file.""" if self.content: sql = self.content else: sql = self.get_sql() query_parameters = [] scheduling_metadata = self.metadata.scheduling if self.metadata else {} if date_partition_parameter := scheduling_metadata.get( "date_partition_parameter", "submission_date" ): query_parameters.append( bigquery.ScalarQueryParameter( date_partition_parameter, "DATE", QUERY_PARAMETER_TYPE_VALUES["DATE"], ) ) for parameter in scheduling_metadata.get("parameters", []): parameter_name, parameter_type, _ = parameter.strip().split(":", 2) parameter_type = parameter_type.upper() or "STRING" query_parameters.append( bigquery.ScalarQueryParameter( parameter_name, parameter_type, QUERY_PARAMETER_TYPE_VALUES.get(parameter_type), ) ) project = basename(dirname(dirname(dirname(self.sqlfile)))) dataset = basename(dirname(dirname(self.sqlfile))) try: start_time = time.time() if self.use_cloud_function: json_data = { "project": self.project or project, "dataset": self.dataset or dataset, "query": sql, "query_parameters": [ query_parameter.to_api_repr() for query_parameter in query_parameters ], } if self.table: json_data["table"] = self.table r = urlopen( Request( self.dry_run_url, headers={ "Content-Type": "application/json", "Authorization": f"Bearer {self.id_token}", }, data=json.dumps(json_data).encode("utf8"), method="POST", ) ) result = json.load(r) else: self.client.project = project job_config = bigquery.QueryJobConfig( dry_run=True, use_query_cache=False, default_dataset=f"{project}.{dataset}", query_parameters=query_parameters, ) job = self.client.query(sql, job_config=job_config) try: dataset_labels = self.client.get_dataset(job.default_dataset).labels except Exception as e: # Most users 
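
    # Sketch of the dry run result consumed by the accessor methods below. Whether it
    # comes from the Cloud Function or from a local dry-run job, a successful result is
    # a dict shaped roughly like the following; "tableMetadata" is only present when
    # project, dataset, and table are all set, and failed runs instead carry "valid":
    # False plus an "errors" list of dicts with "code" and "message" keys. This is an
    # illustration based on the code above, not an exhaustive contract:
    #
    #   {
    #       "valid": True,
    #       "referencedTables": [
    #           {"projectId": "...", "datasetId": "...", "tableId": "..."}
    #       ],
    #       "schema": {"fields": [...]},
    #       "datasetLabels": {...},
    #       "tableMetadata": {"tableType": "...", "friendlyName": "...", "schema": {...}},
    #   }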

    def get_referenced_tables(self):
        """Return referenced tables by dry running the SQL file."""
        if not self.skip() and not self.is_valid():
            raise Exception(f"Error when dry running SQL file {self.sqlfile}")

        if self.skip():
            print(f"\t...Ignoring dryrun results for {self.sqlfile}")

        if (
            self.dry_run_result
            and self.dry_run_result["valid"]
            and "referencedTables" in self.dry_run_result
        ):
            return self.dry_run_result["referencedTables"]

        # Handle views that require a date filter
        if (
            self.dry_run_result
            and self.strip_dml
            and self.get_error() == Errors.DATE_FILTER_NEEDED
        ):
            # Since different queries require different partition filters
            # (submission_date, crash_date, timestamp, submission_timestamp, ...),
            # we can extract the filter name from the error message
            # (by capturing the next word after "column(s)").
            #
            # Example error:
            # "Cannot query over table <table_name> without a filter over column(s)
            # <date_filter_name> that can be used for partition elimination."
            error = self.dry_run_result["errors"][0].get("message", "")
            date_filter = find_next_word("column(s)", error)

            if "date" in date_filter:
                filtered_content = (
                    f"{self.get_sql()}WHERE {date_filter} > current_date()"
                )
                if (
                    DryRun(
                        self.sqlfile,
                        filtered_content,
                        client=self.client,
                        id_token=self.id_token,
                    ).get_error()
                    == Errors.DATE_FILTER_NEEDED_AND_SYNTAX
                ):
                    # If the date filter (e.g. WHERE crash_date > current_date())
                    # is added to a query that already has a WHERE clause,
                    # it will throw an error. To fix this, we need to
                    # append 'AND' instead of 'WHERE'.
                    filtered_content = (
                        f"{self.get_sql()}AND {date_filter} > current_date()"
                    )

            if "timestamp" in date_filter:
                filtered_content = (
                    f"{self.get_sql()}WHERE {date_filter} > current_timestamp()"
                )
                if (
                    DryRun(
                        sqlfile=self.sqlfile,
                        content=filtered_content,
                        client=self.client,
                        id_token=self.id_token,
                    ).get_error()
                    == Errors.DATE_FILTER_NEEDED_AND_SYNTAX
                ):
                    filtered_content = (
                        f"{self.get_sql()}AND {date_filter} > current_timestamp()"
                    )

            stripped_dml_result = DryRun(
                sqlfile=self.sqlfile,
                content=filtered_content,
                client=self.client,
                id_token=self.id_token,
            )
            if (
                stripped_dml_result.get_error() is None
                and "referencedTables" in stripped_dml_result.dry_run_result
            ):
                return stripped_dml_result.dry_run_result["referencedTables"]

        return []

    def get_schema(self):
        """Return the query schema by dry running the SQL file."""
        if not self.skip() and not self.is_valid():
            raise Exception(f"Error when dry running SQL file {self.sqlfile}")

        if self.skip():
            print(f"\t...Ignoring dryrun results for {self.sqlfile}")
            return {}

        if (
            self.dry_run_result
            and self.dry_run_result["valid"]
            and "schema" in self.dry_run_result
        ):
            return self.dry_run_result["schema"]

        return {}

    def get_table_schema(self):
        """Return the schema of the provided table."""
        if not self.skip() and not self.is_valid():
            raise Exception(f"Error when dry running SQL file {self.sqlfile}")

        if self.skip():
            print(f"\t...Ignoring dryrun results for {self.sqlfile}")
            return {}

        if (
            self.dry_run_result
            and self.dry_run_result["valid"]
            and "tableMetadata" in self.dry_run_result
        ):
            return self.dry_run_result["tableMetadata"]["schema"]

        return []

    def get_dataset_labels(self):
        """Return the labels on the default dataset by dry running the SQL file."""
        if not self.skip() and not self.is_valid():
            raise Exception(f"Error when dry running SQL file {self.sqlfile}")

        if self.skip():
            print(f"\t...Ignoring dryrun results for {self.sqlfile}")
            return {}

        if (
            self.dry_run_result
            and self.dry_run_result["valid"]
            and "datasetLabels" in self.dry_run_result
        ):
            return self.dry_run_result["datasetLabels"]

        return {}

    def is_valid(self):
        """Dry run the provided SQL file and check if valid."""
        if self.dry_run_result is None:
            return False

        if self.dry_run_result["valid"]:
            print(f"{self.sqlfile!s:59} OK, took {self.dry_run_duration or 0:.2f}s")
        elif self.get_error() == Errors.READ_ONLY:
            # We want the dryrun service to only have read permissions, so
            # we expect CREATE VIEW and CREATE TABLE to throw specific
            # exceptions.
            print(f"{self.sqlfile!s:59} OK but DDL/DML skipped")
        elif self.get_error() == Errors.DATE_FILTER_NEEDED and self.strip_dml:
            # With the strip_dml flag, some queries require a partition filter
            # (submission_date, submission_timestamp, etc.) to run.
            # We mark these requests as valid and add a date filter
            # in get_referenced_tables().
            print(f"{self.sqlfile!s:59} OK but DATE FILTER NEEDED")
        else:
            print(f"{self.sqlfile!s:59} ERROR\n", self.dry_run_result["errors"])
            return False

        return True

    def errors(self):
        """Dry run the provided SQL file and return errors."""
        if self.dry_run_result is None:
            return []
        return self.dry_run_result.get("errors", [])

    def get_error(self) -> Optional[Errors]:
        """Get specific errors for edge case handling."""
        errors = self.errors()
        if len(errors) != 1:
            return None

        error = errors[0]
        if error and error.get("code") in [400, 403]:
            error_message = error.get("message", "")
            if (
                "does not have bigquery.tables.create permission for dataset"
                in error_message
                or "Permission bigquery.tables.create denied" in error_message
                or "Permission bigquery.datasets.update denied" in error_message
            ):
                return Errors.READ_ONLY
            if "without a filter over column(s)" in error_message:
                return Errors.DATE_FILTER_NEEDED
            if (
                "Syntax error: Expected end of input but got keyword WHERE"
                in error_message
            ):
                return Errors.DATE_FILTER_NEEDED_AND_SYNTAX

        return None

    def validate_schema(self):
        """Check whether schema is valid."""
        # delay import to prevent circular imports in 'bigquery_etl.schema'
        from .schema import SCHEMA_FILE, Schema

        if (
            self.skip()
            or basename(self.sqlfile) == "script.sql"
            or str(self.sqlfile).endswith(".py")
        ):  # noqa E501
            print(f"\t...Ignoring schema validation for {self.sqlfile}")
            return True

        query_file_path = Path(self.sqlfile)
        query_schema = Schema.from_json(self.get_schema())
        if self.errors():
            # ignore the file when there are errors that self.get_schema() did not raise
            click.echo(f"\t...Ignoring schema validation for {self.sqlfile}")
            return True

        existing_schema_path = query_file_path.parent / SCHEMA_FILE
        if not existing_schema_path.is_file():
            click.echo(f"No schema file defined for {query_file_path}", err=True)
            return True

        table_name = query_file_path.parent.name
        dataset_name = query_file_path.parent.parent.name
        project_name = query_file_path.parent.parent.parent.name

        partitioned_by = None
        if (
            self.metadata
            and self.metadata.bigquery
            and self.metadata.bigquery.time_partitioning
        ):
            partitioned_by = self.metadata.bigquery.time_partitioning.field

        table_schema = Schema.for_table(
            project_name,
            dataset_name,
            table_name,
            client=self.client,
            id_token=self.id_token,
            partitioned_by=partitioned_by,
        )

        # This check relies on the new schema being deployed to prod
        if not query_schema.compatible(table_schema):
            click.echo(
                click.style(
                    f"ERROR: Schema for query in {query_file_path} "
                    f"incompatible with schema deployed for "
                    f"{project_name}.{dataset_name}.{table_name}\n"
                    f"Did you deploy the new schema to prod yet?",
                    fg="red",
                ),
                err=True,
            )
            return False
        else:
            existing_schema = Schema.from_schema_file(existing_schema_path)
            if not existing_schema.equal(query_schema):
                click.echo(
                    click.style(
                        f"ERROR: Schema defined in {existing_schema_path} "
                        f"incompatible with query {query_file_path}",
                        fg="red",
                    ),
                    err=True,
                )
                return False

        click.echo(f"Schemas for {query_file_path} are valid.")
        return True


def sql_file_valid(sqlfile):
    """Dry run SQL files."""
    return DryRun(sqlfile).is_valid()


def find_next_word(target, source):
    """Find the next word in a string."""
    split = source.split()
    for i, w in enumerate(split):
        if w == target:
            # get the next word, and remove quotations from column name
            return split[i + 1].replace("'", "")
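

# Minimal usage sketch (illustrative only, not part of the bigquery-etl CLI): dry run a
# single query file and print the tables it references. The query path below is
# hypothetical; point it at any generated query under the configured sql/ directory.
# Requires GCP authentication (`gcloud auth login --update-adc`).
if __name__ == "__main__":
    example_query = (
        "sql/moz-fx-data-shared-prod/telemetry_derived/example_v1/query.sql"
    )
    dry_run = DryRun(example_query)
    if dry_run.is_valid():
        print(dry_run.get_referenced_tables())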