bigquery_etl/copy_deduplicate.py (327 lines of code) (raw):

"""
Copy a day's data from live to stable ping tables, deduplicating on document_id.

By default, the script will process all tables in datasets named
like *_live, copying data into tables of the same name in *_stable
datasets. The script can be configured to exclude a list of tables
or to process only a specific list of tables.
"""

import json
import logging
from datetime import datetime, timedelta
from functools import partial
from itertools import groupby
from multiprocessing.pool import ThreadPool

import click
from google.api_core.exceptions import BadRequest
from google.cloud import bigquery

from bigquery_etl.cli.utils import table_matches_patterns
from bigquery_etl.util.bigquery_id import sql_table_id
from bigquery_etl.util.client_queue import ClientQueue
from bigquery_etl.util.common import TempDatasetReference

from .cli.utils import parallelism_option, project_id_option

# Deduplication query for one live table. Formatted with ``live_table`` and
# run with the query parameters @start_time, @end_time and
# @num_preceding_days (see _query_parameters). Note that the SQL uses ``--``
# line comments, so the newlines in this template are significant.
QUERY_TEMPLATE = """
WITH
  -- Distinct document_ids and their minimum submission_timestamp today
  -- not including document_ids that only occur on or after @end_time
  distinct_document_ids AS (
  SELECT
    document_id,
    MIN(submission_timestamp) AS submission_timestamp
  FROM
    `{live_table}`
  WHERE
    DATE(submission_timestamp) >= DATE_SUB(
      DATE(@start_time),
      INTERVAL @num_preceding_days DAY
    )
    AND submission_timestamp < @end_time
    -- Bug 1657360
    AND 'automation' NOT IN (
      SELECT TRIM(t) FROM UNNEST(SPLIT(metadata.header.x_source_tags, ',')) t
    )
  GROUP BY
    document_id
  HAVING
    submission_timestamp >= @start_time),
  -- A single slice of a live ping table.
  base AS (
  SELECT
    *
  FROM
    `{live_table}`
  JOIN
    distinct_document_ids
  -- Retain only the first seen documents for each ID, according to timestamp.
  USING
    (document_id, submission_timestamp)
  WHERE
    submission_timestamp >= @start_time
    AND submission_timestamp < @end_time),
  --
  -- Order documents by assigning a row number.
  numbered_duplicates AS (
  SELECT
    *,
    ROW_NUMBER() OVER (PARTITION BY document_id) AS _n
  FROM
    base)
  --
  -- Retain only one document for each ID.
SELECT
  * EXCEPT(_n)
FROM
  numbered_duplicates
WHERE
  _n = 1
"""


def _query_parameters(start_time, end_time, preceding_days):
    """Return the scalar query parameters consumed by QUERY_TEMPLATE.

    :param start_time: inclusive lower bound of the slice (TIMESTAMP)
    :param end_time: exclusive upper bound of the slice (TIMESTAMP)
    :param preceding_days: days before ``start_time`` also scanned for
        earlier occurrences of a document_id (INT64)
    """
    return [
        bigquery.ScalarQueryParameter("start_time", "TIMESTAMP", start_time),
        bigquery.ScalarQueryParameter("end_time", "TIMESTAMP", end_time),
        bigquery.ScalarQueryParameter("num_preceding_days", "INT64", preceding_days),
    ]


def _get_query_job_configs(
    client,
    live_table,
    date,
    dry_run,
    slices,
    priority,
    preceding_days,
    num_retries,
    temp_dataset,
):
    """Build the deduplication query job(s) for one live table and one date.

    Returns a list of ``(sql, stable_table, job_config, num_retries)`` tuples.

    With ``slices <= 1`` there is a single tuple whose query writes directly
    to the stable table's date partition (``<table>$YYYYMMDD``) with
    WRITE_TRUNCATE; ``stable_table`` is then a table-id string.

    With ``slices > 1`` the day is split into ``slices`` equal time ranges and
    each query is a CREATE TABLE ... AS statement writing to a fresh table in
    ``temp_dataset``. Those temp tables copy the stable table's partitioning
    field and clustering fields and expire one day after creation. In this
    case ``stable_table`` is the ``bigquery.Table`` fetched via ``client``
    (shared by all slices of the same partition).
    """
    sql = QUERY_TEMPLATE.format(live_table=live_table)
    stable_table = f"{live_table.replace('_live.', '_stable.', 1)}${date:%Y%m%d}"
    kwargs = dict(use_legacy_sql=False, dry_run=dry_run, priority=priority)
    start_time = datetime(*date.timetuple()[:6])
    end_time = start_time + timedelta(days=1)

    if slices <= 1:
        return [
            (
                sql,
                stable_table,
                bigquery.QueryJobConfig(
                    destination=stable_table,
                    write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
                    query_parameters=_query_parameters(
                        start_time, end_time, preceding_days
                    ),
                    **kwargs,
                ),
                num_retries,
            )
        ]

    # create temporary tables with stable_table's time_partitioning and
    # clustering_fields, and a 1-day expiration
    stable_table = client.get_table(stable_table)
    ddl = "CREATE TABLE\n  `{dest}`"
    ddl += f"\nPARTITION BY\n  DATE({stable_table.time_partitioning.field})"
    if stable_table.clustering_fields:
        ddl += f"\nCLUSTER BY\n  {', '.join(stable_table.clustering_fields)}"
    ddl += (
        "\nOPTIONS"
        "\n  ("
        "\n    partition_expiration_days = CAST('inf' AS FLOAT64),"
        "\n    expiration_timestamp = "
        "TIMESTAMP_ADD(CURRENT_TIMESTAMP, INTERVAL 1 DAY)"
        "\n  )"
    )

    slice_size = (end_time - start_time) / slices
    # explicitly use end_time as the last boundary to avoid rounding errors
    boundaries = [start_time + slice_size * i for i in range(slices)] + [end_time]
    return [
        (
            f"{ddl.format(dest=temp_dataset.temp_table())}\nAS\n{sql.strip()}",
            stable_table,
            bigquery.QueryJobConfig(
                query_parameters=_query_parameters(
                    slice_start, slice_end, preceding_days
                ),
                **kwargs,
            ),
            num_retries,
        )
        for slice_start, slice_end in zip(boundaries, boundaries[1:])
    ]


def _run_deduplication_query(client, sql, stable_table, job_config, num_retries):
    """Run one deduplication query job and wait for it to finish.

    For non-dry-run jobs, a ``BadRequest`` raised while waiting for the
    result is retried (by re-submitting the same SQL and config) up to
    ``num_retries`` more times, after which it is re-raised.

    Returns ``(stable_table, query_job)`` so the caller can group the jobs
    by the partition they populate.
    """
    query_job = client.query(sql, job_config, job_id_prefix="copy_dedup_")
    if not query_job.dry_run:
        try:
            query_job.result()
        except BadRequest as e:
            if num_retries <= 0:
                raise
            # logging.warn is deprecated, and passing ``e`` without a
            # placeholder made the logging module fail to format the record.
            logging.warning("Encountered bad request, retrying: %s", e)
            return _run_deduplication_query(
                client, sql, stable_table, job_config, num_retries - 1
            )
        logging.info(
            f"Completed query job {query_job.job_id} for {stable_table}"
            f" with params: {job_config.query_parameters}"
        )
    return stable_table, query_job


def _copy_join_parts(client, stable_table, query_jobs):
    """Report on the query jobs for one stable partition and merge slices.

    For dry runs, only logs the bytes that would be processed (and, when
    sliced, that the parts would be copied). Otherwise logs bytes and slot
    hours; if there is more than one job, copies each job's destination
    partition into ``stable_table`` with WRITE_TRUNCATE and then deletes the
    per-slice destination tables.

    NOTE(review): assumes each sliced CREATE TABLE ... AS job reports the
    created temp table as ``job.destination`` — confirm against the
    BigQuery job resource if this ever changes.
    """
    # total_bytes_processed / slot_millis are Optional on QueryJob; treat a
    # missing value as zero rather than crashing the summary.
    total_bytes = sum(query.total_bytes_processed or 0 for query in query_jobs)
    if query_jobs[0].dry_run:
        api_repr = json.dumps(query_jobs[0].to_api_repr())
        if len(query_jobs) > 1:
            logging.info(f"Would process {total_bytes} bytes: [{api_repr},...]")
            logging.info(f"Would copy {len(query_jobs)} results to {stable_table}")
        else:
            logging.info(f"Would process {total_bytes} bytes: {api_repr}")
        return

    total_slot_hours = round(
        sum(query.slot_millis or 0 for query in query_jobs) / 1000 / 60 / 60, 3
    )
    logging.info(
        f"Processed {total_bytes} bytes in {total_slot_hours}"
        f" slot hours to populate {stable_table}"
    )
    if len(query_jobs) > 1:
        partition_id = stable_table.table_id.split("$", 1)[1]
        sources = [
            f"{sql_table_id(job.destination)}${partition_id}" for job in query_jobs
        ]
        job_config = bigquery.CopyJobConfig(
            write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE
        )
        copy_job = client.copy_table(sources, stable_table, job_config=job_config)
        copy_job.result()
        logging.info(f"Copied {len(query_jobs)} results to populate {stable_table}")
        for job in query_jobs:
            client.delete_table(job.destination)
        logging.info(f"Deleted {len(query_jobs)} temporary tables")


def _contains_glob(patterns):
    """Return True if any pattern contains a glob character (``*``, ``?``, ``[``)."""
    return any(set("*?[").intersection(pattern) for pattern in patterns)


def _glob_dataset(pattern):
    """Return the dataset part of a ``dataset.table`` pattern."""
    return pattern.split(".", 1)[0]


def _list_live_tables(client, pool, project_id, only_tables, table_filter):
    """Return fully-qualified ids of the live tables to process.

    If ``only_tables`` is given without globs, it is used as-is (filtered by
    ``table_filter``) with no API listing. If only the table parts contain
    globs, just the named datasets are listed; otherwise every dataset in
    ``project_id`` whose id ends with ``_live`` is listed. Tables whose id
    contains ``beam_load_sink`` are always skipped when listing.
    """
    if only_tables and not _contains_glob(only_tables):
        # skip list calls when only_tables exists and contains no globs
        return [f"{project_id}.{t}" for t in only_tables if table_filter(t)]
    if only_tables and not _contains_glob(_glob_dataset(t) for t in only_tables):
        # skip list_datasets call when only_tables exists and datasets contain no globs
        live_datasets = {f"{project_id}.{_glob_dataset(t)}" for t in only_tables}
    else:
        live_datasets = [
            d.reference
            for d in client.list_datasets(project_id)
            if d.dataset_id.endswith("_live")
        ]
    # NOTE(review): client.list_tables returns a lazy iterator, so pages are
    # probably fetched while iterating below rather than inside the pool.
    return [
        sql_table_id(t)
        for tables in pool.map(client.list_tables, live_datasets)
        for t in tables
        if table_filter(f"{t.dataset_id}.{t.table_id}")
        and "beam_load_sink" not in t.table_id
    ]


@click.command(
    "copy_deduplicate",
    help="Copy a day's data from live to stable ping tables, deduplicating on document_id",
)
@project_id_option("moz-fx-data-shar-nonprod-efed")
@click.option(
    "--dates",
    "--date",
    multiple=True,
    required=True,
    type=lambda d: datetime.strptime(d, "%Y-%m-%d").date(),
    help="One or more days of data to copy, in format 2019-01-01",
)
@parallelism_option()
@click.option(
    "--dry_run",
    "--dry-run",
    is_flag=True,
    help="Do not make changes, only log actions that would be taken",
)
@click.option(
    "--log-level",
    "--log_level",
    help="Log level.",
    default=logging.getLevelName(logging.INFO),
    type=str.upper,
)
@click.option(
    "--priority",
    default=bigquery.QueryPriority.INTERACTIVE,
    type=click.Choice(
        [bigquery.QueryPriority.BATCH, bigquery.QueryPriority.INTERACTIVE],
        case_sensitive=False,
    ),
    help="Priority for BigQuery query jobs; BATCH priority may significantly slow "
    "down queries if reserved slots are not enabled for the billing project; "
    "INTERACTIVE priority is limited to 100 concurrent queries per project",
)
@click.option(
    "--temp-dataset",
    "--temp_dataset",
    "--temporary-dataset",
    "--temporary_dataset",
    default="moz-fx-data-shared-prod.tmp",
    type=TempDatasetReference.from_string,
    help="Dataset where intermediate query results will be temporarily stored, "
    "formatted as PROJECT_ID.DATASET_ID",
)
@click.option(
    "--slices",
    type=int,
    default=1,
    help=(
        "Number of queries to split deduplicate over, each handling an equal-size time "
        "slice of the date; avoids memory overflow at the cost of less effective "
        "clustering; recommended only for tables failing due to memory overflow"
    ),
)
@click.option(
    "--hourly",
    is_flag=True,
    help="Deduplicate one hour at a time; equivalent to --slices=24",
)
@click.option(
    "--preceding_days",
    "--preceding-days",
    type=int,
    default=0,
    help="Number of days preceding --date that should be used to filter out duplicates",
)
@click.option(
    "--num_retries",
    "--num-retries",
    type=int,
    default=2,
    help="Number of times to retry each slice in case of query error",
)
@click.option(
    "--billing-projects",
    "--billing_projects",
    "--billing-project",
    "--billing_project",
    "-p",
    multiple=True,
    default=[None],
    help="One or more billing projects over which bigquery jobs should be "
    "distributed",
)
@click.option(
    "--except",
    "-x",
    "exclude",
    multiple=True,
    help="Process all tables except for the given tables",
)
@click.option(
    "--only",
    "-o",
    multiple=True,
    help="Process only the given tables",
)
def copy_deduplicate(
    project_id,
    dates,
    parallelism,
    dry_run,
    log_level,
    priority,
    temp_dataset,
    slices,
    hourly,
    preceding_days,
    num_retries,
    billing_projects,
    exclude,
    only,
):
    """Copy a day's data from live to stable ping tables, dedup on document_id."""
    # NOTE(review): ``log_level`` is accepted but not applied in this function;
    # confirm whether logging is configured by the calling CLI.

    # create a queue for balancing load across projects
    client_q = ClientQueue(billing_projects, parallelism)

    if hourly:
        slices = 24

    table_filter = partial(table_matches_patterns, "*", False)
    if only:
        table_filter = partial(table_matches_patterns, list(only), False)
    elif exclude:
        table_filter = partial(table_matches_patterns, list(exclude), True)

    with ThreadPool(parallelism) as pool:
        with client_q.client() as client:
            live_tables = _list_live_tables(
                client=client,
                pool=pool,
                project_id=project_id,
                only_tables=only,
                table_filter=table_filter,
            )

            query_jobs = [
                (_run_deduplication_query, *args)
                for jobs in pool.starmap(
                    _get_query_job_configs,
                    [
                        (
                            client,  # only use one client to create temp tables
                            live_table,
                            date,
                            dry_run,
                            slices,
                            priority,
                            preceding_days,
                            num_retries,
                            temp_dataset,
                        )
                        for live_table in live_tables
                        for date in dates
                    ],
                )
                for args in jobs
            ]

        # preserve query_jobs order so results stay sorted by stable_table for groupby
        results = pool.starmap(client_q.with_client, query_jobs, chunksize=1)
        copy_jobs = [
            (_copy_join_parts, stable_table, [query_job for _, query_job in group])
            for stable_table, group in groupby(results, key=lambda result: result[0])
        ]
        pool.starmap(client_q.with_client, copy_jobs, chunksize=1)