bigquery_etl/shredder/delete.py (710 lines of code) (raw):

"""Delete user data from long term storage.""" import logging import warnings from argparse import ArgumentParser from collections import defaultdict from dataclasses import dataclass, replace from datetime import datetime, timedelta from functools import partial from multiprocessing.pool import ThreadPool from operator import attrgetter from textwrap import dedent from typing import Callable, Iterable, Optional, Tuple, Union from google.api_core.exceptions import NotFound from google.cloud import bigquery from google.cloud.bigquery import QueryJob from ..format_sql.formatter import reformat from ..util import standard_args from ..util.bigquery_id import FULL_JOB_ID_RE, full_job_id, sql_table_id from ..util.client_queue import ClientQueue from ..util.exceptions import BigQueryInsertError from .config import ( DELETE_TARGETS, DeleteSource, DeleteTarget, find_experiment_analysis_targets, find_glean_targets, find_pioneer_targets, ) NULL_PARTITION_ID = "__NULL__" OUTSIDE_RANGE_PARTITION_ID = "__UNPARTITIONED__" parser = ArgumentParser(description=__doc__) standard_args.add_dry_run(parser) parser.add_argument( "--environment", default="telemetry", const="telemetry", nargs="?", choices=["telemetry", "pioneer", "experiments"], help="environment to run in (dictates the choice of source and target tables): " "telemetry - standard environment, " "pioneer - restricted pioneer environment, " "experiments - experiment analysis tables", ) parser.add_argument( "--pioneer-study-projects", "--pioneer_study_projects", default=[], help="Pioneer study-specific analysis projects to include in data deletion.", type=lambda s: [i for i in s.split(",")], ) parser.add_argument( "--partition-limit", "--partition_limit", metavar="N", type=int, help="Only use the first N partitions per table; requires --dry-run", ) parser.add_argument( "--no-use-dml", "--no_use_dml", action="store_false", dest="use_dml", help="Use SELECT * FROM instead of DELETE in queries to avoid concurrent DML limit " "or errors due to read-only permissions being insufficient to dry run DML; unless " "used with --dry-run, DML will still be used for special partitions like __NULL__", ) standard_args.add_log_level(parser) standard_args.add_parallelism(parser) parser.add_argument( "-e", "--end-date", "--end_date", default=datetime.utcnow().date(), type=lambda x: datetime.strptime(x, "%Y-%m-%d").date(), help="last date of pings to delete; One day after last date of " "deletion requests to process; defaults to today in UTC", ) parser.add_argument( "-s", "--start-date", "--start_date", type=lambda x: datetime.strptime(x, "%Y-%m-%d").date(), help="first date of deletion requests to process; DOES NOT apply to ping date; " "defaults to 14 days before --end-date in UTC", ) standard_args.add_billing_projects(parser, default=["moz-fx-data-bq-batch-prod"]) parser.add_argument( "--source-project", "--source_project", help="override the project used for deletion request tables", ) parser.add_argument( "--target-project", "--target_project", help="override the project used for target tables", ) parser.add_argument( "--max-single-dml-bytes", "--max_single_dml_bytes", default=10 * 2**40, type=int, help="Maximum number of bytes in a table that should be processed using a single " "DML query; tables above this limit will be processed using per-partition " "queries; this option prevents queries against large tables from exceeding the " "6-hour time limit; defaults to 10 TiB", ) standard_args.add_priority(parser) parser.add_argument( "--state-table", "--state_table", metavar="TABLE", help="Table for recording state; Used to avoid repeating deletes if interrupted; " "Create it if it does not exist; By default state is not recorded", ) parser.add_argument( "--task-table", "--task_table", metavar="TABLE", help="Table for recording tasks; Used along with --state-table to determine " "progress; Create it if it does not exist; By default tasks are not recorded", ) standard_args.add_table_filter(parser) parser.add_argument( "--sampling-tables", "--sampling_tables", nargs="+", dest="sampling_tables", help="Create tasks per sample id for the given table(s). Table format is dataset.table_name.", default=[], ) parser.add_argument( "--sampling-parallelism", "--sampling_parallelism", type=int, default=10, help="Number of concurrent queries to run per partition when shredding per sample id", ) parser.add_argument( "--temp-dataset", "--temp_dataset", metavar="PROJECT.DATASET", help="Dataset (project.dataset format) to write intermediate results of sampled queries to. " "Must be specified when --sampling-tables is set.", ) @dataclass class DeleteJobResults: """Subset of a bigquery job object retaining only fields that are needed.""" job_id: str total_bytes_processed: Optional[int] num_dml_affected_rows: Optional[int] destination: str def record_state(client, task_id, job, dry_run, start_date, end_date, state_table): """Record the job for task_id in state_table.""" if state_table is not None: job_id = "a job_id" if dry_run else full_job_id(job) insert_tense = "Would insert" if dry_run else "Inserting" logging.info(f"{insert_tense} {job_id} in {state_table} for task: {task_id}") if not dry_run: BigQueryInsertError.raise_if_present( errors=client.insert_rows_json( state_table, [ { "task_id": task_id, "job_id": job_id, "job_created": job.created.isoformat(), "start_date": start_date.isoformat(), "end_date": end_date.isoformat(), } ], skip_invalid_rows=False, ) ) def wait_for_job( client, states, task_id, dry_run, create_job, check_table_existence=False, **state_kwargs, ) -> DeleteJobResults: """Get a job from state or create a new job, and wait for the job to complete.""" job = None if task_id in states: job = client.get_job(**FULL_JOB_ID_RE.fullmatch(states[task_id]).groupdict()) # type: ignore[union-attr] if job.errors: logging.info(f"Previous attempt failed, retrying for {task_id}") job = None elif job.ended: # if destination table no longer exists (temp table expired), rerun job try: if check_table_existence: client.get_table(job.destination) logging.info( f"Previous attempt succeeded, reusing result for {task_id}" ) except NotFound: logging.info(f"Previous result expired, retrying for {task_id}") job = None else: logging.info(f"Previous attempt still running for {task_id}") if job is None: job = create_job(client) record_state( client=client, task_id=task_id, dry_run=dry_run, job=job, **state_kwargs ) if not dry_run and not job.ended: logging.info(f"Waiting on {full_job_id(job)} for {task_id}") job.result() try: bytes_processed = job.total_bytes_processed except AttributeError: bytes_processed = 0 return DeleteJobResults( job_id=job.job_id, total_bytes_processed=bytes_processed, num_dml_affected_rows=( job.num_dml_affected_rows if isinstance(job, QueryJob) else None ), destination=job.destination, ) def get_task_id(target, partition_id): """Get unique task id for state tracking.""" task_id = sql_table_id(target) if partition_id: task_id += f"${partition_id}" return task_id @dataclass class Partition: """Return type for get_partition.""" condition: str id: Optional[str] = None is_special: bool = False def _override_query_with_fxa_id_in_extras( field_condition: str, target: DeleteTarget, sources: Iterable[DeleteSource], source_condition: str, ) -> str: """Override query to handle fxa_id nested in event extras in relay_backend_stable.events_v1.""" sources = list(sources) if ( target.table == "relay_backend_stable.events_v1" and len(target.fields) == 1 and target.fields[0] == "events[*].extra.fxa_id" and len(sources) == 1 and sources[0].table == "firefox_accounts.fxa_delete_events" and sources[0].field == "user_id" ): field_condition = ( """ EXISTS ( WITH user_ids AS ( SELECT user_id_unhashed AS user_id FROM `moz-fx-data-shared-prod.firefox_accounts.fxa_delete_events` WHERE """ + " AND ".join((source_condition, *sources[0].conditions)) + """) SELECT 1 FROM UNNEST(events) AS e JOIN UNNEST(e.extra) AS ex JOIN user_ids u ON ex.value = u.user_id WHERE ex.key = 'fxa_id' ) """ ) return field_condition def delete_from_partition( dry_run: bool, partition: Partition, priority: str, source_condition: str, sources: Iterable[DeleteSource], target: DeleteTarget, use_dml: bool, sample_id: Optional[int] = None, temp_dataset: Optional[str] = None, clustering_fields: Optional[Iterable[str]] = None, **wait_for_job_kwargs, ): """Return callable to handle deletion requests for partitions of a target table.""" job_config = bigquery.QueryJobConfig(dry_run=dry_run, priority=priority) # whole table operations must use DML to protect against dropping partitions in the # case of conflicting write operations in ETL, and special partitions must use DML # because they can't be set as a query destination. if partition.id is None or partition.is_special: use_dml = True elif sample_id is not None: use_dml = False job_config.destination = ( f"{temp_dataset}.{target.table_id}_{partition.id}__sample_{sample_id}" ) job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE job_config.clustering_fields = clustering_fields elif not use_dml: job_config.destination = f"{sql_table_id(target)}${partition.id}" job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE def create_job(client) -> bigquery.QueryJob: if use_dml: field_condition = " OR ".join( f""" {field} IN ( SELECT {source.field} FROM `{sql_table_id(source)}` WHERE """ + " AND ".join((source_condition, *source.conditions)) + ")" for field, source in zip(target.fields, sources) ) # Temporary workaround for fxa_id nested in event extras in relay_backend_stable.events_v1 # We'll be able to remove this once fxa_id is migrated to string metric # See https://mozilla-hub.atlassian.net/browse/DENG-7965 and 7964 field_condition = _override_query_with_fxa_id_in_extras( field_condition, target, sources, source_condition ) query = reformat( f""" DELETE `{sql_table_id(target)}` WHERE ({field_condition}) AND ({partition.condition}) """ ) else: field_joins = "".join( ( f""" LEFT JOIN ( SELECT {source.field} AS _source_{index} FROM `{sql_table_id(source)}` WHERE """ + " AND ".join((source_condition, *source.conditions)) + (f" AND sample_id = {sample_id}" if sample_id is not None else "") + f""" ) ON {field} = _source_{index} """ ) for index, (field, source) in enumerate(zip(target.fields, sources)) ) field_conditions = " AND ".join( f"_source_{index} IS NULL" for index, _ in enumerate(sources) ) if partition.id is None: # only apply field conditions on partition.condition field_conditions = f""" ({partition.condition}) IS NOT TRUE OR ({field_conditions}) """ # always true partition condition to satisfy require_partition_filter partition_condition = f""" ({partition.condition}) IS NOT TRUE OR ({partition.condition}) """ else: partition_condition = partition.condition query = reformat( f""" SELECT _target.*, FROM `{sql_table_id(target)}` AS _target {field_joins} WHERE ({field_conditions}) AND ({partition_condition}) {f" AND sample_id = {sample_id}" if sample_id is not None else ""} """ ) run_tense = "Would run" if dry_run else "Running" logging.debug(f"{run_tense} query: {query}") return client.query(query, job_config=job_config) return partial( wait_for_job, create_job=create_job, dry_run=dry_run, **wait_for_job_kwargs ) def delete_from_partition_with_sampling( dry_run: bool, partition: Partition, priority: str, source_condition: str, sources: Iterable[DeleteSource], target: DeleteTarget, use_dml: bool, sampling_parallelism: int, temp_dataset: str, **wait_for_job_kwargs, ): """Return callable to delete from a partition of a target table per sample id.""" copy_job_config = bigquery.CopyJobConfig( write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE, create_disposition=bigquery.CreateDisposition.CREATE_IF_NEEDED, ) target_table = f"{sql_table_id(target)}${partition.id}" def delete_by_sample(client) -> Union[bigquery.CopyJob, bigquery.QueryJob]: intermediate_clustering_fields = client.get_table( target_table ).clustering_fields tasks = [ delete_from_partition( dry_run=dry_run, partition=partition, priority=priority, source_condition=source_condition, sources=sources, target=target, use_dml=use_dml, temp_dataset=temp_dataset, sample_id=s, clustering_fields=intermediate_clustering_fields, check_table_existence=True, **{ **wait_for_job_kwargs, # override task id with sample id suffix "task_id": f"{wait_for_job_kwargs['task_id']}__sample_{s}", }, ) for s in range(100) ] # Run all 100 delete functions in parallel, exception is raised without retry if any fail with ThreadPool(sampling_parallelism) as pool: jobs = [ pool.apply_async( task, args=(client,), ) for task in tasks ] results = [job.get() for job in jobs] intermediate_tables = [result.destination for result in results] run_tense = "Would copy" if dry_run else "Copying" logging.debug( f"{run_tense} {len(intermediate_tables)} " f"{[str(t) for t in intermediate_tables]} to {target_table}" ) if not dry_run: copy_job = client.copy_table( sources=intermediate_tables, destination=target_table, job_config=copy_job_config, ) # simulate query job properties for logging copy_job.total_bytes_processed = sum( [r.total_bytes_processed for r in results] ) return copy_job else: # copy job doesn't have dry runs so dry run base partition for byte estimate return client.query( f"SELECT * FROM `{sql_table_id(target)}` WHERE {partition.condition}", job_config=bigquery.QueryJobConfig(dry_run=True), ) return partial( wait_for_job, create_job=delete_by_sample, dry_run=dry_run, **wait_for_job_kwargs, ) def get_partition_expr(table): """Get the SQL expression to use for a table's partitioning field.""" if table.range_partitioning: return table.range_partitioning.field if table.time_partitioning: return f"CAST({table.time_partitioning.field or '_PARTITIONTIME'} AS DATE)" def get_partition(table, partition_expr, end_date, id_=None) -> Optional[Partition]: """Return a Partition for id_ unless it is a date on or after end_date.""" if id_ is None: if table.time_partitioning: return Partition(condition=f"{partition_expr} < '{end_date}'") return Partition(condition="TRUE") if id_ == NULL_PARTITION_ID: if table.time_partitioning: return Partition( condition=f"{table.time_partitioning.field} IS NULL", id=id_, is_special=True, ) return Partition(condition=f"{partition_expr} IS NULL", id=id_, is_special=True) if table.time_partitioning: date = datetime.strptime(id_, "%Y%m%d").date() if date < end_date: return Partition(f"{partition_expr} = '{date}'", id_) return None if table.range_partitioning: if id_ == OUTSIDE_RANGE_PARTITION_ID: return Partition( condition=f"{partition_expr} < {table.range_partitioning.range_.start} " f"OR {partition_expr} >= {table.range_partitioning.range_.end}", id=id_, is_special=True, ) if table.range_partitioning.range_.interval > 1: return Partition( condition=f"{partition_expr} BETWEEN {id_} " f"AND {int(id_) + table.range_partitioning.range_.interval - 1}", id=id_, ) return Partition(condition=f"{partition_expr} = {id_}", id=id_) def list_partitions( client, table, partition_expr, end_date, max_single_dml_bytes, partition_limit ): """List the relevant partitions in a table.""" partitions = [ partition for partition in ( [ get_partition(table, partition_expr, end_date, row["partition_id"]) for row in client.query( dedent( f""" SELECT partition_id FROM [{sql_table_id(table)}$__PARTITIONS_SUMMARY__] """ ).strip(), bigquery.QueryJobConfig(use_legacy_sql=True), ).result() ] if table.num_bytes > max_single_dml_bytes and partition_expr is not None else [get_partition(table, partition_expr, end_date)] ) if partition is not None ] if partition_limit: return sorted(partitions, key=attrgetter("id"), reverse=True)[:partition_limit] return partitions @dataclass class Task: """Return type for delete_from_table.""" table: bigquery.Table sources: Tuple[DeleteSource] partition_id: Optional[str] func: Callable[[bigquery.Client], bigquery.QueryJob] @property def partition_sort_key(self): """Return a tuple to control the order in which tasks will be handled. When used with reverse=True, handle tasks without partition_id, then tasks without time_partitioning, then most recent dates first. """ return ( self.partition_id is None, self.table.time_partitioning is None, self.partition_id, ) def delete_from_table( client, target, sources, dry_run, end_date, max_single_dml_bytes, partition_limit, sampling_parallelism, use_sampling, temp_dataset, **kwargs, ) -> Iterable[Task]: """Yield tasks to handle deletion requests for a target table.""" try: table = client.get_table(sql_table_id(target)) except NotFound: logging.warning(f"Skipping {sql_table_id(target)} due to NotFound exception") return () # type: ignore partition_expr = get_partition_expr(table) for partition in list_partitions( client, table, partition_expr, end_date, max_single_dml_bytes, partition_limit ): # no sampling for __NULL__ partition if use_sampling and not partition.is_special: kwargs["sampling_parallelism"] = sampling_parallelism delete_func: Callable = delete_from_partition_with_sampling else: if use_sampling: logging.warning( "Cannot use sampling on full table deletion, " f"{target.dataset_id}.{target.table_id} is too small to use sampling" ) kwargs.pop("sampling_parallelism", None) delete_func = delete_from_partition yield Task( table=table, sources=sources, partition_id=partition.id, func=delete_func( dry_run=dry_run, partition=partition, target=target, sources=sources, task_id=get_task_id(target, partition.id), end_date=end_date, temp_dataset=temp_dataset, **kwargs, ), ) def main(): """Process deletion requests.""" args = parser.parse_args() if args.partition_limit is not None and not args.dry_run: parser.print_help() logging.warning("ERROR: --partition-limit specified without --dry-run") if len(args.sampling_tables) > 0 and args.temp_dataset is None: parser.error("--temp-dataset must be specified when using --sampling-tables") if args.start_date is None: args.start_date = args.end_date - timedelta(days=14) source_condition = ( f"DATE(submission_timestamp) >= '{args.start_date}' " f"AND DATE(submission_timestamp) < '{args.end_date}'" ) client_q = ClientQueue( args.billing_projects, args.parallelism, connection_pool_max_size=( max(args.parallelism * args.sampling_parallelism, 12) if len(args.sampling_tables) > 0 else None ), ) client = client_q.default_client states = {} if args.state_table: state_table_exists = False try: client.get_table(args.state_table) state_table_exists = True except NotFound: if not args.dry_run: client.create_table( bigquery.Table( args.state_table, [ bigquery.SchemaField("task_id", "STRING"), bigquery.SchemaField("job_id", "STRING"), bigquery.SchemaField("job_created", "TIMESTAMP"), bigquery.SchemaField("start_date", "DATE"), bigquery.SchemaField("end_date", "DATE"), ], ) ) state_table_exists = True if state_table_exists: states = dict( client.query( reformat( f""" SELECT task_id, job_id, FROM `{args.state_table}` WHERE end_date = '{args.end_date}' ORDER BY job_created """ ) ).result() ) if args.environment == "telemetry": with ThreadPool(6) as pool: glean_targets = find_glean_targets(pool, client) targets_with_sources = ( *DELETE_TARGETS.items(), *glean_targets.items(), ) elif args.environment == "experiments": targets_with_sources = find_experiment_analysis_targets(client).items() elif args.environment == "pioneer": with ThreadPool(args.parallelism) as pool: targets_with_sources = find_pioneer_targets( pool, client, study_projects=args.pioneer_study_projects ).items() missing_sampling_tables = [ t for t in args.sampling_tables if t not in [target.table for target, _ in targets_with_sources] ] if len(missing_sampling_tables) > 0: raise ValueError( f"{len(missing_sampling_tables)} sampling tables not found in " f"targets: {missing_sampling_tables}" ) tasks = [ task for target, sources in targets_with_sources if args.table_filter(target.table) for task in delete_from_table( client=client, target=replace(target, project=args.target_project or target.project), sources=[ replace(source, project=args.source_project or source.project) for source in (sources if isinstance(sources, tuple) else (sources,)) ], source_condition=source_condition, dry_run=args.dry_run, use_dml=args.use_dml, priority=args.priority, start_date=args.start_date, end_date=args.end_date, max_single_dml_bytes=args.max_single_dml_bytes, partition_limit=args.partition_limit, state_table=args.state_table, states=states, sampling_parallelism=args.sampling_parallelism, use_sampling=target.table in args.sampling_tables, temp_dataset=args.temp_dataset, ) ] if not tasks: logging.error("No tables selected") parser.exit(1) # ORDER BY partition_sort_key DESC, sql_table_id ASC # https://docs.python.org/3/howto/sorting.html#sort-stability-and-complex-sorts tasks.sort(key=lambda task: sql_table_id(task.table)) tasks.sort(key=attrgetter("partition_sort_key"), reverse=True) with ThreadPool(args.parallelism) as pool: if args.task_table and not args.dry_run: # record task information try: client.get_table(args.task_table) except NotFound: table = bigquery.Table( args.task_table, [ bigquery.SchemaField("task_id", "STRING"), bigquery.SchemaField("start_date", "DATE"), bigquery.SchemaField("end_date", "DATE"), bigquery.SchemaField("target", "STRING"), bigquery.SchemaField("target_rows", "INT64"), bigquery.SchemaField("target_bytes", "INT64"), bigquery.SchemaField("source_bytes", "INT64"), ], ) table.time_partitioning = bigquery.TimePartitioning() client.create_table(table) sources = list(set(source for task in tasks for source in task.sources)) source_bytes = { source: job.total_bytes_processed for source, job in zip( sources, pool.starmap( client.query, [ ( reformat( f""" SELECT {source.field} FROM `{sql_table_id(source)}` WHERE {source_condition} """ ), bigquery.QueryJobConfig(dry_run=True), ) for source in sources ], chunksize=1, ), ) } step = 10000 # max 10K rows per insert for start in range(0, len(tasks), step): end = start + step BigQueryInsertError.raise_if_present( errors=client.insert_rows_json( args.task_table, [ { "task_id": get_task_id(task.table, task.partition_id), "start_date": args.start_date.isoformat(), "end_date": args.end_date.isoformat(), "target": sql_table_id(task.table), "target_rows": task.table.num_rows, "target_bytes": task.table.num_bytes, "source_bytes": sum( map(source_bytes.get, task.sources) ), } for task in tasks[start:end] ], ) ) results = pool.map( client_q.with_client, (task.func for task in tasks), chunksize=1 ) jobs_by_table = defaultdict(list) for i, job in enumerate(results): jobs_by_table[tasks[i].table].append(job) bytes_processed = rows_deleted = 0 for table, jobs in jobs_by_table.items(): table_bytes_processed = sum(job.total_bytes_processed or 0 for job in jobs) bytes_processed += table_bytes_processed table_id = sql_table_id(table) if args.dry_run: logging.info(f"Would scan {table_bytes_processed} bytes from {table_id}") else: table_rows_deleted = sum(job.num_dml_affected_rows or 0 for job in jobs) rows_deleted += table_rows_deleted logging.info( f"Scanned {table_bytes_processed} bytes and " f"deleted {table_rows_deleted} rows from {table_id}" ) if args.dry_run: logging.info(f"Would scan {bytes_processed} in total") else: logging.info( f"Scanned {bytes_processed} and deleted {rows_deleted} rows in total" ) if __name__ == "__main__": warnings.filterwarnings("ignore", module="google.auth._default") main()