"""Delete user data from long term storage."""

import logging
import warnings
from argparse import ArgumentParser
from collections import defaultdict
from dataclasses import dataclass, replace
from datetime import datetime, timedelta
from functools import partial
from multiprocessing.pool import ThreadPool
from operator import attrgetter
from textwrap import dedent
from typing import Callable, Iterable, Optional, Tuple, Union

from google.api_core.exceptions import NotFound
from google.cloud import bigquery
from google.cloud.bigquery import QueryJob

from ..format_sql.formatter import reformat
from ..util import standard_args
from ..util.bigquery_id import FULL_JOB_ID_RE, full_job_id, sql_table_id
from ..util.client_queue import ClientQueue
from ..util.exceptions import BigQueryInsertError
from .config import (
    DELETE_TARGETS,
    DeleteSource,
    DeleteTarget,
    find_experiment_analysis_targets,
    find_glean_targets,
    find_pioneer_targets,
)

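# BigQuery reports two special partition ids: __NULL__ holds rows whose partitioning
# column is NULL, and __UNPARTITIONED__ holds rows outside the configured range (or
# still in the streaming buffer); they can't be set as a query destination, so they
# are always handled with DML.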
NULL_PARTITION_ID = "__NULL__"
OUTSIDE_RANGE_PARTITION_ID = "__UNPARTITIONED__"

parser = ArgumentParser(description=__doc__)
standard_args.add_dry_run(parser)
parser.add_argument(
    "--environment",
    default="telemetry",
    const="telemetry",
    nargs="?",
    choices=["telemetry", "pioneer", "experiments"],
    help="environment to run in (dictates the choice of source and target tables): "
    "telemetry - standard environment, "
    "pioneer - restricted pioneer environment, "
    "experiments - experiment analysis tables",
)
parser.add_argument(
    "--pioneer-study-projects",
    "--pioneer_study_projects",
    default=[],
    help="Pioneer study-specific analysis projects to include in data deletion.",
    type=lambda s: s.split(","),
)
parser.add_argument(
    "--partition-limit",
    "--partition_limit",
    metavar="N",
    type=int,
    help="Only use the first N partitions per table; requires --dry-run",
)
parser.add_argument(
    "--no-use-dml",
    "--no_use_dml",
    action="store_false",
    dest="use_dml",
    help="Use SELECT * FROM instead of DELETE in queries to avoid concurrent DML limit "
    "or errors due to read-only permissions being insufficient to dry run DML; unless "
    "used with --dry-run, DML will still be used for special partitions like __NULL__",
)
standard_args.add_log_level(parser)
standard_args.add_parallelism(parser)
parser.add_argument(
    "-e",
    "--end-date",
    "--end_date",
    default=datetime.utcnow().date(),
    type=lambda x: datetime.strptime(x, "%Y-%m-%d").date(),
    help="last date of pings to delete; One day after last date of "
    "deletion requests to process; defaults to today in UTC",
)
parser.add_argument(
    "-s",
    "--start-date",
    "--start_date",
    type=lambda x: datetime.strptime(x, "%Y-%m-%d").date(),
    help="first date of deletion requests to process; DOES NOT apply to ping date; "
    "defaults to 14 days before --end-date in UTC",
)
standard_args.add_billing_projects(parser, default=["moz-fx-data-bq-batch-prod"])
parser.add_argument(
    "--source-project",
    "--source_project",
    help="override the project used for deletion request tables",
)
parser.add_argument(
    "--target-project",
    "--target_project",
    help="override the project used for target tables",
)
parser.add_argument(
    "--max-single-dml-bytes",
    "--max_single_dml_bytes",
    default=10 * 2**40,
    type=int,
    help="Maximum number of bytes in a table that should be processed using a single "
    "DML query; tables above this limit will be processed using per-partition "
    "queries; this option prevents queries against large tables from exceeding the "
    "6-hour time limit; defaults to 10 TiB",
)
standard_args.add_priority(parser)
parser.add_argument(
    "--state-table",
    "--state_table",
    metavar="TABLE",
    help="Table for recording state; Used to avoid repeating deletes if interrupted; "
    "Create it if it does not exist; By default state is not recorded",
)
parser.add_argument(
    "--task-table",
    "--task_table",
    metavar="TABLE",
    help="Table for recording tasks; Used along with --state-table to determine "
    "progress; Create it if it does not exist; By default tasks are not recorded",
)
standard_args.add_table_filter(parser)
parser.add_argument(
    "--sampling-tables",
    "--sampling_tables",
    nargs="+",
    dest="sampling_tables",
    help="Create tasks per sample id for the given table(s).  Table format is dataset.table_name.",
    default=[],
)
parser.add_argument(
    "--sampling-parallelism",
    "--sampling_parallelism",
    type=int,
    default=10,
    help="Number of concurrent queries to run per partition when shredding per sample id",
)
parser.add_argument(
    "--temp-dataset",
    "--temp_dataset",
    metavar="PROJECT.DATASET",
    help="Dataset (project.dataset format) to write intermediate results of sampled queries to. "
    "Must be specified when --sampling-tables is set.",
)


@dataclass
class DeleteJobResults:
    """Subset of a bigquery job object retaining only fields that are needed."""

    job_id: str
    total_bytes_processed: Optional[int]
    num_dml_affected_rows: Optional[int]
    destination: str


def record_state(client, task_id, job, dry_run, start_date, end_date, state_table):
    """Record the job for task_id in state_table."""
    if state_table is not None:
        job_id = "a job_id" if dry_run else full_job_id(job)
        insert_tense = "Would insert" if dry_run else "Inserting"
        logging.info(f"{insert_tense} {job_id} in {state_table} for task: {task_id}")
        if not dry_run:
            BigQueryInsertError.raise_if_present(
                errors=client.insert_rows_json(
                    state_table,
                    [
                        {
                            "task_id": task_id,
                            "job_id": job_id,
                            "job_created": job.created.isoformat(),
                            "start_date": start_date.isoformat(),
                            "end_date": end_date.isoformat(),
                        }
                    ],
                    skip_invalid_rows=False,
                )
            )


def wait_for_job(
    client,
    states,
    task_id,
    dry_run,
    create_job,
    check_table_existence=False,
    **state_kwargs,
) -> DeleteJobResults:
    """Get a job from state or create a new job, and wait for the job to complete."""
    job = None
    if task_id in states:
        job = client.get_job(**FULL_JOB_ID_RE.fullmatch(states[task_id]).groupdict())  # type: ignore[union-attr]
        if job.errors:
            logging.info(f"Previous attempt failed, retrying for {task_id}")
            job = None
        elif job.ended:
            # if destination table no longer exists (temp table expired), rerun job
            try:
                if check_table_existence:
                    client.get_table(job.destination)
                logging.info(
                    f"Previous attempt succeeded, reusing result for {task_id}"
                )
            except NotFound:
                logging.info(f"Previous result expired, retrying for {task_id}")
                job = None
        else:
            logging.info(f"Previous attempt still running for {task_id}")
    if job is None:
        job = create_job(client)
        record_state(
            client=client, task_id=task_id, dry_run=dry_run, job=job, **state_kwargs
        )
    if not dry_run and not job.ended:
        logging.info(f"Waiting on {full_job_id(job)} for {task_id}")
        job.result()

    try:
        bytes_processed = job.total_bytes_processed
    except AttributeError:
        bytes_processed = 0

    return DeleteJobResults(
        job_id=job.job_id,
        total_bytes_processed=bytes_processed,
        num_dml_affected_rows=(
            job.num_dml_affected_rows if isinstance(job, QueryJob) else None
        ),
        destination=job.destination,
    )


def get_task_id(target, partition_id):
    """Get unique task id for state tracking."""
    task_id = sql_table_id(target)
    if partition_id:
        task_id += f"${partition_id}"
    return task_id


@dataclass
class Partition:
    """Return type for get_partition."""

    condition: str
    id: Optional[str] = None
    is_special: bool = False


def _override_query_with_fxa_id_in_extras(
    field_condition: str,
    target: DeleteTarget,
    sources: Iterable[DeleteSource],
    source_condition: str,
) -> str:
    """Override query to handle fxa_id nested in event extras in relay_backend_stable.events_v1."""
    sources = list(sources)
    if (
        target.table == "relay_backend_stable.events_v1"
        and len(target.fields) == 1
        and target.fields[0] == "events[*].extra.fxa_id"
        and len(sources) == 1
        and sources[0].table == "firefox_accounts.fxa_delete_events"
        and sources[0].field == "user_id"
    ):
        field_condition = (
            """
                EXISTS (
                  WITH user_ids AS (
                    SELECT
                      user_id_unhashed AS user_id
                    FROM
                      `moz-fx-data-shared-prod.firefox_accounts.fxa_delete_events`
                    WHERE """
            + " AND ".join((source_condition, *sources[0].conditions))
            + """)
                  SELECT 1
                  FROM UNNEST(events) AS e
                  JOIN UNNEST(e.extra) AS ex
                  JOIN user_ids u
                  ON ex.value = u.user_id
                  WHERE ex.key = 'fxa_id'
                )
                """
        )
    return field_condition


def delete_from_partition(
    dry_run: bool,
    partition: Partition,
    priority: str,
    source_condition: str,
    sources: Iterable[DeleteSource],
    target: DeleteTarget,
    use_dml: bool,
    sample_id: Optional[int] = None,
    temp_dataset: Optional[str] = None,
    clustering_fields: Optional[Iterable[str]] = None,
    **wait_for_job_kwargs,
):
    """Return callable to handle deletion requests for partitions of a target table."""
    job_config = bigquery.QueryJobConfig(dry_run=dry_run, priority=priority)
    # whole table operations must use DML to protect against dropping partitions in the
    # case of conflicting write operations in ETL, and special partitions must use DML
    # because they can't be set as a query destination.
    if partition.id is None or partition.is_special:
        use_dml = True
    elif sample_id is not None:
        use_dml = False
        job_config.destination = (
            f"{temp_dataset}.{target.table_id}_{partition.id}__sample_{sample_id}"
        )
        job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE
        job_config.clustering_fields = clustering_fields
    elif not use_dml:
        job_config.destination = f"{sql_table_id(target)}${partition.id}"
        job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE

    def create_job(client) -> bigquery.QueryJob:
        if use_dml:
            field_condition = " OR ".join(
                f"""
                {field} IN (
                  SELECT
                    {source.field}
                  FROM
                    `{sql_table_id(source)}`
                  WHERE
                """
                + " AND ".join((source_condition, *source.conditions))
                + ")"
                for field, source in zip(target.fields, sources)
            )

            # Temporary workaround for fxa_id nested in event extras in
            # relay_backend_stable.events_v1. We'll be able to remove this once fxa_id
            # is migrated to a string metric.
            # See https://mozilla-hub.atlassian.net/browse/DENG-7965 and DENG-7964
            field_condition = _override_query_with_fxa_id_in_extras(
                field_condition, target, sources, source_condition
            )

            query = reformat(
                f"""
                DELETE
                  `{sql_table_id(target)}`
                WHERE
                  ({field_condition})
                  AND ({partition.condition})
                """
            )
        else:
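            # Non-DML path: anti-join the target against each deletion-request source
            # and keep only rows with no matching request, overwriting the partition
            # (or an intermediate sample table) via the job's WRITE_TRUNCATE destination.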
            field_joins = "".join(
                (
                    f"""
                LEFT JOIN
                  (
                    SELECT
                      {source.field} AS _source_{index}
                    FROM
                      `{sql_table_id(source)}`
                    WHERE
                """
                    + " AND ".join((source_condition, *source.conditions))
                    + (f" AND sample_id = {sample_id}" if sample_id is not None else "")
                    + f"""
                  )
                  ON {field} = _source_{index}
                """
                )
                for index, (field, source) in enumerate(zip(target.fields, sources))
            )
            field_conditions = " AND ".join(
                f"_source_{index} IS NULL" for index, _ in enumerate(sources)
            )

            if partition.id is None:
                # only apply field conditions to rows that match partition.condition
                field_conditions = f"""
                ({partition.condition}) IS NOT TRUE
                OR ({field_conditions})
                """
                # always true partition condition to satisfy require_partition_filter
                partition_condition = f"""
                ({partition.condition}) IS NOT TRUE
                OR ({partition.condition})
                """
            else:
                partition_condition = partition.condition

            query = reformat(
                f"""
                SELECT
                  _target.*,
                FROM
                  `{sql_table_id(target)}` AS _target
                {field_joins}
                WHERE
                  ({field_conditions})
                  AND ({partition_condition})
                  {f" AND sample_id = {sample_id}" if sample_id is not None else ""}
                """
            )
        run_tense = "Would run" if dry_run else "Running"
        logging.debug(f"{run_tense} query: {query}")
        return client.query(query, job_config=job_config)

    return partial(
        wait_for_job, create_job=create_job, dry_run=dry_run, **wait_for_job_kwargs
    )


def delete_from_partition_with_sampling(
    dry_run: bool,
    partition: Partition,
    priority: str,
    source_condition: str,
    sources: Iterable[DeleteSource],
    target: DeleteTarget,
    use_dml: bool,
    sampling_parallelism: int,
    temp_dataset: str,
    **wait_for_job_kwargs,
):
    """Return callable to delete from a partition of a target table per sample id."""
    copy_job_config = bigquery.CopyJobConfig(
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
        create_disposition=bigquery.CreateDisposition.CREATE_IF_NEEDED,
    )
    target_table = f"{sql_table_id(target)}${partition.id}"

    def delete_by_sample(client) -> Union[bigquery.CopyJob, bigquery.QueryJob]:
        intermediate_clustering_fields = client.get_table(
            target_table
        ).clustering_fields

        tasks = [
            delete_from_partition(
                dry_run=dry_run,
                partition=partition,
                priority=priority,
                source_condition=source_condition,
                sources=sources,
                target=target,
                use_dml=use_dml,
                temp_dataset=temp_dataset,
                sample_id=s,
                clustering_fields=intermediate_clustering_fields,
                check_table_existence=True,
                **{
                    **wait_for_job_kwargs,
                    # override task id with sample id suffix
                    "task_id": f"{wait_for_job_kwargs['task_id']}__sample_{s}",
                },
            )
            for s in range(100)
        ]

        # Run all 100 delete functions in parallel; if any fail, an exception is
        # raised without retrying
        with ThreadPool(sampling_parallelism) as pool:
            jobs = [
                pool.apply_async(
                    task,
                    args=(client,),
                )
                for task in tasks
            ]

            results = [job.get() for job in jobs]

        intermediate_tables = [result.destination for result in results]
        run_tense = "Would copy" if dry_run else "Copying"
        logging.debug(
            f"{run_tense} {len(intermediate_tables)} "
            f"{[str(t) for t in intermediate_tables]} to {target_table}"
        )
        if not dry_run:
            copy_job = client.copy_table(
                sources=intermediate_tables,
                destination=target_table,
                job_config=copy_job_config,
            )
            # simulate query job properties for logging
            copy_job.total_bytes_processed = sum(
                [r.total_bytes_processed for r in results]
            )
            return copy_job
        else:
            # copy jobs don't support dry run, so dry run the base partition for a byte estimate
            return client.query(
                f"SELECT * FROM `{sql_table_id(target)}` WHERE {partition.condition}",
                job_config=bigquery.QueryJobConfig(dry_run=True),
            )

    return partial(
        wait_for_job,
        create_job=delete_by_sample,
        dry_run=dry_run,
        **wait_for_job_kwargs,
    )


def get_partition_expr(table):
    """Get the SQL expression to use for a table's partitioning field."""
    if table.range_partitioning:
        return table.range_partitioning.field
    if table.time_partitioning:
        return f"CAST({table.time_partitioning.field or '_PARTITIONTIME'} AS DATE)"


def get_partition(table, partition_expr, end_date, id_=None) -> Optional[Partition]:
    """Return a Partition for id_ unless it is a date on or after end_date."""
    if id_ is None:
        if table.time_partitioning:
            return Partition(condition=f"{partition_expr} < '{end_date}'")
        return Partition(condition="TRUE")
    if id_ == NULL_PARTITION_ID:
        if table.time_partitioning:
            return Partition(
                condition=f"{table.time_partitioning.field} IS NULL",
                id=id_,
                is_special=True,
            )
        return Partition(condition=f"{partition_expr} IS NULL", id=id_, is_special=True)
    if table.time_partitioning:
        date = datetime.strptime(id_, "%Y%m%d").date()
        if date < end_date:
            return Partition(f"{partition_expr} = '{date}'", id_)
        return None
    if table.range_partitioning:
        if id_ == OUTSIDE_RANGE_PARTITION_ID:
            return Partition(
                condition=f"{partition_expr} < {table.range_partitioning.range_.start} "
                f"OR {partition_expr} >= {table.range_partitioning.range_.end}",
                id=id_,
                is_special=True,
            )
        if table.range_partitioning.range_.interval > 1:
            return Partition(
                condition=f"{partition_expr} BETWEEN {id_} "
                f"AND {int(id_) + table.range_partitioning.range_.interval - 1}",
                id=id_,
            )
    return Partition(condition=f"{partition_expr} = {id_}", id=id_)


def list_partitions(
    client, table, partition_expr, end_date, max_single_dml_bytes, partition_limit
):
    """List the relevant partitions in a table."""
    partitions = [
        partition
        for partition in (
            [
                get_partition(table, partition_expr, end_date, row["partition_id"])
                for row in client.query(
                    dedent(
                        f"""
                        SELECT
                          partition_id
                        FROM
                          [{sql_table_id(table)}$__PARTITIONS_SUMMARY__]
                        """
                    ).strip(),
                    bigquery.QueryJobConfig(use_legacy_sql=True),
                ).result()
            ]
            if table.num_bytes > max_single_dml_bytes and partition_expr is not None
            else [get_partition(table, partition_expr, end_date)]
        )
        if partition is not None
    ]
    if partition_limit:
        return sorted(partitions, key=attrgetter("id"), reverse=True)[:partition_limit]
    return partitions


@dataclass
class Task:
    """Return type for delete_from_table."""

    table: bigquery.Table
    sources: Tuple[DeleteSource, ...]
    partition_id: Optional[str]
    func: Callable[[bigquery.Client], DeleteJobResults]

    @property
    def partition_sort_key(self):
        """Return a tuple to control the order in which tasks will be handled.

        When used with reverse=True, handle tasks without partition_id, then
        tasks without time_partitioning, then most recent dates first.
        """
        return (
            self.partition_id is None,
            self.table.time_partitioning is None,
            self.partition_id,
        )


def delete_from_table(
    client,
    target,
    sources,
    dry_run,
    end_date,
    max_single_dml_bytes,
    partition_limit,
    sampling_parallelism,
    use_sampling,
    temp_dataset,
    **kwargs,
) -> Iterable[Task]:
    """Yield tasks to handle deletion requests for a target table."""
    try:
        table = client.get_table(sql_table_id(target))
    except NotFound:
        logging.warning(f"Skipping {sql_table_id(target)} due to NotFound exception")
        return ()  # type: ignore
    partition_expr = get_partition_expr(table)
    for partition in list_partitions(
        client, table, partition_expr, end_date, max_single_dml_bytes, partition_limit
    ):
        # no sampling for whole-table deletes or special partitions like __NULL__
        if use_sampling and partition.id is not None and not partition.is_special:
            kwargs["sampling_parallelism"] = sampling_parallelism
            delete_func: Callable = delete_from_partition_with_sampling
        else:
            if use_sampling:
                logging.warning(
                    "Cannot use sampling on whole-table deletes or special partitions, "
                    f"processing {target.dataset_id}.{target.table_id} without sampling"
                )
            kwargs.pop("sampling_parallelism", None)
            delete_func = delete_from_partition

        yield Task(
            table=table,
            sources=sources,
            partition_id=partition.id,
            func=delete_func(
                dry_run=dry_run,
                partition=partition,
                target=target,
                sources=sources,
                task_id=get_task_id(target, partition.id),
                end_date=end_date,
                temp_dataset=temp_dataset,
                **kwargs,
            ),
        )


def main():
    """Process deletion requests."""
    args = parser.parse_args()
    if args.partition_limit is not None and not args.dry_run:
        parser.error("--partition-limit requires --dry-run")
    if len(args.sampling_tables) > 0 and args.temp_dataset is None:
        parser.error("--temp-dataset must be specified when using --sampling-tables")
    if args.start_date is None:
        args.start_date = args.end_date - timedelta(days=14)
    source_condition = (
        f"DATE(submission_timestamp) >= '{args.start_date}' "
        f"AND DATE(submission_timestamp) < '{args.end_date}'"
    )
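    # when sampling, up to parallelism * sampling_parallelism queries may run at
    # once, so the shared HTTP connection pool is widened to match (floor of 12)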
    client_q = ClientQueue(
        args.billing_projects,
        args.parallelism,
        connection_pool_max_size=(
            max(args.parallelism * args.sampling_parallelism, 12)
            if len(args.sampling_tables) > 0
            else None
        ),
    )
    client = client_q.default_client
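    # map of task_id -> job_id for jobs already recorded against this end date,
    # letting wait_for_job reuse results from an interrupted run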
    states = {}
    if args.state_table:
        state_table_exists = False
        try:
            client.get_table(args.state_table)
            state_table_exists = True
        except NotFound:
            if not args.dry_run:
                client.create_table(
                    bigquery.Table(
                        args.state_table,
                        [
                            bigquery.SchemaField("task_id", "STRING"),
                            bigquery.SchemaField("job_id", "STRING"),
                            bigquery.SchemaField("job_created", "TIMESTAMP"),
                            bigquery.SchemaField("start_date", "DATE"),
                            bigquery.SchemaField("end_date", "DATE"),
                        ],
                    )
                )
                state_table_exists = True
        if state_table_exists:
            states = dict(
                client.query(
                    reformat(
                        f"""
                        SELECT
                          task_id,
                          job_id,
                        FROM
                          `{args.state_table}`
                        WHERE
                          end_date = '{args.end_date}'
                        ORDER BY
                          job_created
                        """
                    )
                ).result()
            )

    if args.environment == "telemetry":
        with ThreadPool(6) as pool:
            glean_targets = find_glean_targets(pool, client)
        targets_with_sources = (
            *DELETE_TARGETS.items(),
            *glean_targets.items(),
        )
    elif args.environment == "experiments":
        targets_with_sources = find_experiment_analysis_targets(client).items()
    elif args.environment == "pioneer":
        with ThreadPool(args.parallelism) as pool:
            targets_with_sources = find_pioneer_targets(
                pool, client, study_projects=args.pioneer_study_projects
            ).items()

    missing_sampling_tables = [
        t
        for t in args.sampling_tables
        if t not in [target.table for target, _ in targets_with_sources]
    ]
    if len(missing_sampling_tables) > 0:
        raise ValueError(
            f"{len(missing_sampling_tables)} sampling tables not found in "
            f"targets: {missing_sampling_tables}"
        )

    tasks = [
        task
        for target, sources in targets_with_sources
        if args.table_filter(target.table)
        for task in delete_from_table(
            client=client,
            target=replace(target, project=args.target_project or target.project),
            sources=[
                replace(source, project=args.source_project or source.project)
                for source in (sources if isinstance(sources, tuple) else (sources,))
            ],
            source_condition=source_condition,
            dry_run=args.dry_run,
            use_dml=args.use_dml,
            priority=args.priority,
            start_date=args.start_date,
            end_date=args.end_date,
            max_single_dml_bytes=args.max_single_dml_bytes,
            partition_limit=args.partition_limit,
            state_table=args.state_table,
            states=states,
            sampling_parallelism=args.sampling_parallelism,
            use_sampling=target.table in args.sampling_tables,
            temp_dataset=args.temp_dataset,
        )
    ]

    if not tasks:
        logging.error("No tables selected")
        parser.exit(1)
    # ORDER BY partition_sort_key DESC, sql_table_id ASC
    # https://docs.python.org/3/howto/sorting.html#sort-stability-and-complex-sorts
    tasks.sort(key=lambda task: sql_table_id(task.table))
    tasks.sort(key=attrgetter("partition_sort_key"), reverse=True)

    with ThreadPool(args.parallelism) as pool:
        if args.task_table and not args.dry_run:
            # record task information
            try:
                client.get_table(args.task_table)
            except NotFound:
                table = bigquery.Table(
                    args.task_table,
                    [
                        bigquery.SchemaField("task_id", "STRING"),
                        bigquery.SchemaField("start_date", "DATE"),
                        bigquery.SchemaField("end_date", "DATE"),
                        bigquery.SchemaField("target", "STRING"),
                        bigquery.SchemaField("target_rows", "INT64"),
                        bigquery.SchemaField("target_bytes", "INT64"),
                        bigquery.SchemaField("source_bytes", "INT64"),
                    ],
                )
                table.time_partitioning = bigquery.TimePartitioning()
                client.create_table(table)
            sources = list(set(source for task in tasks for source in task.sources))
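            # dry run a SELECT of each source's field to estimate the bytes each
            # deletion-request source will scan; recorded per task below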
            source_bytes = {
                source: job.total_bytes_processed
                for source, job in zip(
                    sources,
                    pool.starmap(
                        client.query,
                        [
                            (
                                reformat(
                                    f"""
                                    SELECT
                                      {source.field}
                                    FROM
                                      `{sql_table_id(source)}`
                                    WHERE
                                      {source_condition}
                                    """
                                ),
                                bigquery.QueryJobConfig(dry_run=True),
                            )
                            for source in sources
                        ],
                        chunksize=1,
                    ),
                )
            }
            step = 10000  # max 10K rows per insert
            for start in range(0, len(tasks), step):
                end = start + step
                BigQueryInsertError.raise_if_present(
                    errors=client.insert_rows_json(
                        args.task_table,
                        [
                            {
                                "task_id": get_task_id(task.table, task.partition_id),
                                "start_date": args.start_date.isoformat(),
                                "end_date": args.end_date.isoformat(),
                                "target": sql_table_id(task.table),
                                "target_rows": task.table.num_rows,
                                "target_bytes": task.table.num_bytes,
                                "source_bytes": sum(
                                    map(source_bytes.get, task.sources)
                                ),
                            }
                            for task in tasks[start:end]
                        ],
                    )
                )
        results = pool.map(
            client_q.with_client, (task.func for task in tasks), chunksize=1
        )
    jobs_by_table = defaultdict(list)
    for i, job in enumerate(results):
        jobs_by_table[tasks[i].table].append(job)
    bytes_processed = rows_deleted = 0
    for table, jobs in jobs_by_table.items():
        table_bytes_processed = sum(job.total_bytes_processed or 0 for job in jobs)
        bytes_processed += table_bytes_processed
        table_id = sql_table_id(table)
        if args.dry_run:
            logging.info(f"Would scan {table_bytes_processed} bytes from {table_id}")
        else:
            table_rows_deleted = sum(job.num_dml_affected_rows or 0 for job in jobs)
            rows_deleted += table_rows_deleted
            logging.info(
                f"Scanned {table_bytes_processed} bytes and "
                f"deleted {table_rows_deleted} rows from {table_id}"
            )
    if args.dry_run:
        logging.info(f"Would scan {bytes_processed} in total")
    else:
        logging.info(
            f"Scanned {bytes_processed} bytes and deleted {rows_deleted} rows in total"
        )


if __name__ == "__main__":
    warnings.filterwarnings("ignore", module="google.auth._default")
    main()
