def main()

in bigquery_etl/shredder/delete.py

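Entry point for the shredder delete command: it parses command-line arguments, resolves deletion targets and their deletion-request sources for the selected environment, expands them into per-table deletion tasks, runs those tasks in parallel across the configured billing projects, and logs the bytes scanned and rows deleted.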

def main():
    """Process deletion requests."""
    args = parser.parse_args()
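    # --partition-limit is only supported together with --dry-run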
    if args.partition_limit is not None and not args.dry_run:
        parser.print_help()
        logging.warning("ERROR: --partition-limit specified without --dry-run")
        return
    if len(args.sampling_tables) > 0 and args.temp_dataset is None:
        parser.error("--temp-dataset must be specified when using --sampling-tables")
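    # default to a 14-day window ending at --end-date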
    if args.start_date is None:
        args.start_date = args.end_date - timedelta(days=14)
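    # limit source-table scans to the configured date range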
    source_condition = (
        f"DATE(submission_timestamp) >= '{args.start_date}' "
        f"AND DATE(submission_timestamp) < '{args.end_date}'"
    )
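    # distribute queries across billing projects, with a larger connection
    # pool when sampling tables are configured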
    client_q = ClientQueue(
        args.billing_projects,
        args.parallelism,
        connection_pool_max_size=(
            max(args.parallelism * args.sampling_parallelism, 12)
            if len(args.sampling_tables) > 0
            else None
        ),
    )
    client = client_q.default_client
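    # task_id -> job_id mapping loaded from the state table, if one is configured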
    states = {}
    if args.state_table:
        state_table_exists = False
        try:
            client.get_table(args.state_table)
            state_table_exists = True
        except NotFound:
            if not args.dry_run:
                client.create_table(
                    bigquery.Table(
                        args.state_table,
                        [
                            bigquery.SchemaField("task_id", "STRING"),
                            bigquery.SchemaField("job_id", "STRING"),
                            bigquery.SchemaField("job_created", "TIMESTAMP"),
                            bigquery.SchemaField("start_date", "DATE"),
                            bigquery.SchemaField("end_date", "DATE"),
                        ],
                    )
                )
                state_table_exists = True
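        # load job ids already recorded for this end date, ordered by creation time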
        if state_table_exists:
            states = dict(
                client.query(
                    reformat(
                        f"""
                        SELECT
                          task_id,
                          job_id,
                        FROM
                          `{args.state_table}`
                        WHERE
                          end_date = '{args.end_date}'
                        ORDER BY
                          job_created
                        """
                    )
                ).result()
            )

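    # look up deletion targets and their deletion-request sources for the
    # selected environment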
    if args.environment == "telemetry":
        with ThreadPool(6) as pool:
            glean_targets = find_glean_targets(pool, client)
        targets_with_sources = (
            *DELETE_TARGETS.items(),
            *glean_targets.items(),
        )
    elif args.environment == "experiments":
        targets_with_sources = find_experiment_analysis_targets(client).items()
    elif args.environment == "pioneer":
        with ThreadPool(args.parallelism) as pool:
            targets_with_sources = find_pioneer_targets(
                pool, client, study_projects=args.pioneer_study_projects
            ).items()

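    # every table passed via --sampling-tables must be a known deletion target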
    missing_sampling_tables = [
        t
        for t in args.sampling_tables
        if t not in [target.table for target, _ in targets_with_sources]
    ]
    if len(missing_sampling_tables) > 0:
        raise ValueError(
            f"{len(missing_sampling_tables)} sampling tables not found in "
            f"targets: {missing_sampling_tables}"
        )

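    # expand each matching target into deletion tasks via delete_from_table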
    tasks = [
        task
        for target, sources in targets_with_sources
        if args.table_filter(target.table)
        for task in delete_from_table(
            client=client,
            target=replace(target, project=args.target_project or target.project),
            sources=[
                replace(source, project=args.source_project or source.project)
                for source in (sources if isinstance(sources, tuple) else (sources,))
            ],
            source_condition=source_condition,
            dry_run=args.dry_run,
            use_dml=args.use_dml,
            priority=args.priority,
            start_date=args.start_date,
            end_date=args.end_date,
            max_single_dml_bytes=args.max_single_dml_bytes,
            partition_limit=args.partition_limit,
            state_table=args.state_table,
            states=states,
            sampling_parallelism=args.sampling_parallelism,
            use_sampling=target.table in args.sampling_tables,
            temp_dataset=args.temp_dataset,
        )
    ]

    if not tasks:
        logging.error("No tables selected")
        parser.exit(1)
    # ORDER BY partition_sort_key DESC, sql_table_id ASC
    # https://docs.python.org/3/howto/sorting.html#sort-stability-and-complex-sorts
    tasks.sort(key=lambda task: sql_table_id(task.table))
    tasks.sort(key=attrgetter("partition_sort_key"), reverse=True)

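    # execute tasks in parallel, optionally recording per-task metadata in
    # --task-table first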
    with ThreadPool(args.parallelism) as pool:
        if args.task_table and not args.dry_run:
            # record task information
            try:
                client.get_table(args.task_table)
            except NotFound:
                table = bigquery.Table(
                    args.task_table,
                    [
                        bigquery.SchemaField("task_id", "STRING"),
                        bigquery.SchemaField("start_date", "DATE"),
                        bigquery.SchemaField("end_date", "DATE"),
                        bigquery.SchemaField("target", "STRING"),
                        bigquery.SchemaField("target_rows", "INT64"),
                        bigquery.SchemaField("target_bytes", "INT64"),
                        bigquery.SchemaField("source_bytes", "INT64"),
                    ],
                )
                table.time_partitioning = bigquery.TimePartitioning()
                client.create_table(table)
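            # dry-run one query per distinct source to estimate the bytes it will scan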
            sources = list(set(source for task in tasks for source in task.sources))
            source_bytes = {
                source: job.total_bytes_processed
                for source, job in zip(
                    sources,
                    pool.starmap(
                        client.query,
                        [
                            (
                                reformat(
                                    f"""
                                    SELECT
                                      {source.field}
                                    FROM
                                      `{sql_table_id(source)}`
                                    WHERE
                                      {source_condition}
                                    """
                                ),
                                bigquery.QueryJobConfig(dry_run=True),
                            )
                            for source in sources
                        ],
                        chunksize=1,
                    ),
                )
            }
            step = 10000  # max 10K rows per insert
            for start in range(0, len(tasks), step):
                end = start + step
                BigQueryInsertError.raise_if_present(
                    errors=client.insert_rows_json(
                        args.task_table,
                        [
                            {
                                "task_id": get_task_id(task.table, task.partition_id),
                                "start_date": args.start_date.isoformat(),
                                "end_date": args.end_date.isoformat(),
                                "target": sql_table_id(task.table),
                                "target_rows": task.table.num_rows,
                                "target_bytes": task.table.num_bytes,
                                "source_bytes": sum(
                                    map(source_bytes.get, task.sources)
                                ),
                            }
                            for task in tasks[start:end]
                        ],
                    )
                )
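        # run each task's query with a client leased from the client queue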
        results = pool.map(
            client_q.with_client, (task.func for task in tasks), chunksize=1
        )
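    # aggregate per-table results for logging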
    jobs_by_table = defaultdict(list)
    for i, job in enumerate(results):
        jobs_by_table[tasks[i].table].append(job)
    bytes_processed = rows_deleted = 0
    for table, jobs in jobs_by_table.items():
        table_bytes_processed = sum(job.total_bytes_processed or 0 for job in jobs)
        bytes_processed += table_bytes_processed
        table_id = sql_table_id(table)
        if args.dry_run:
            logging.info(f"Would scan {table_bytes_processed} bytes from {table_id}")
        else:
            table_rows_deleted = sum(job.num_dml_affected_rows or 0 for job in jobs)
            rows_deleted += table_rows_deleted
            logging.info(
                f"Scanned {table_bytes_processed} bytes and "
                f"deleted {table_rows_deleted} rows from {table_id}"
            )
    if args.dry_run:
        logging.info(f"Would scan {bytes_processed} in total")
    else:
        logging.info(
            f"Scanned {bytes_processed} and deleted {rows_deleted} rows in total"
        )