in bigquery_etl/shredder/delete.py [0:0]
def main():
"""Process deletion requests."""
args = parser.parse_args()
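    # Validate flag combinations up front: --partition-limit is only meaningful for
    # dry runs, and --sampling-tables requires --temp-dataset.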
    if args.partition_limit is not None and not args.dry_run:
        parser.print_help()
        logging.error("--partition-limit specified without --dry-run")
        parser.exit(1)
if len(args.sampling_tables) > 0 and args.temp_dataset is None:
parser.error("--temp-dataset must be specified when using --sampling-tables")
if args.start_date is None:
args.start_date = args.end_date - timedelta(days=14)
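    # Restrict source (deletion request) scans to the [start_date, end_date) window.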
source_condition = (
f"DATE(submission_timestamp) >= '{args.start_date}' "
f"AND DATE(submission_timestamp) < '{args.end_date}'"
)
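    # Spread deletion queries across the configured billing projects, and widen the
    # HTTP connection pool when sampling adds extra per-table parallelism.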
client_q = ClientQueue(
args.billing_projects,
args.parallelism,
connection_pool_max_size=(
max(args.parallelism * args.sampling_parallelism, 12)
if len(args.sampling_tables) > 0
else None
),
)
client = client_q.default_client
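    # With --state-table, load the job ids of tasks already recorded for this
    # end_date so completed work can be skipped when the run is retried.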
states = {}
if args.state_table:
state_table_exists = False
try:
client.get_table(args.state_table)
state_table_exists = True
except NotFound:
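            # Create the state table on first use; dry runs record nothing, so skip creation.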
if not args.dry_run:
client.create_table(
bigquery.Table(
args.state_table,
[
bigquery.SchemaField("task_id", "STRING"),
bigquery.SchemaField("job_id", "STRING"),
bigquery.SchemaField("job_created", "TIMESTAMP"),
bigquery.SchemaField("start_date", "DATE"),
bigquery.SchemaField("end_date", "DATE"),
],
)
)
state_table_exists = True
if state_table_exists:
states = dict(
client.query(
reformat(
f"""
SELECT
task_id,
job_id,
FROM
`{args.state_table}`
WHERE
end_date = '{args.end_date}'
ORDER BY
job_created
"""
)
).result()
)
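    # Resolve deletion targets and their deletion-request sources for the selected environment.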
if args.environment == "telemetry":
with ThreadPool(6) as pool:
glean_targets = find_glean_targets(pool, client)
targets_with_sources = (
*DELETE_TARGETS.items(),
*glean_targets.items(),
)
elif args.environment == "experiments":
targets_with_sources = find_experiment_analysis_targets(client).items()
elif args.environment == "pioneer":
with ThreadPool(args.parallelism) as pool:
targets_with_sources = find_pioneer_targets(
pool, client, study_projects=args.pioneer_study_projects
).items()
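    # Every table named in --sampling-tables must be among the resolved targets.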
missing_sampling_tables = [
t
for t in args.sampling_tables
if t not in [target.table for target, _ in targets_with_sources]
]
if len(missing_sampling_tables) > 0:
raise ValueError(
f"{len(missing_sampling_tables)} sampling tables not found in "
f"targets: {missing_sampling_tables}"
)
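    # Expand each selected target into per-table or per-partition deletion tasks,
    # applying any project overrides for targets and sources.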
tasks = [
task
for target, sources in targets_with_sources
if args.table_filter(target.table)
for task in delete_from_table(
client=client,
target=replace(target, project=args.target_project or target.project),
sources=[
replace(source, project=args.source_project or source.project)
for source in (sources if isinstance(sources, tuple) else (sources,))
],
source_condition=source_condition,
dry_run=args.dry_run,
use_dml=args.use_dml,
priority=args.priority,
start_date=args.start_date,
end_date=args.end_date,
max_single_dml_bytes=args.max_single_dml_bytes,
partition_limit=args.partition_limit,
state_table=args.state_table,
states=states,
sampling_parallelism=args.sampling_parallelism,
use_sampling=target.table in args.sampling_tables,
temp_dataset=args.temp_dataset,
)
]
if not tasks:
logging.error("No tables selected")
parser.exit(1)
# ORDER BY partition_sort_key DESC, sql_table_id ASC
# https://docs.python.org/3/howto/sorting.html#sort-stability-and-complex-sorts
tasks.sort(key=lambda task: sql_table_id(task.table))
tasks.sort(key=attrgetter("partition_sort_key"), reverse=True)
with ThreadPool(args.parallelism) as pool:
if args.task_table and not args.dry_run:
# record task information
try:
client.get_table(args.task_table)
except NotFound:
table = bigquery.Table(
args.task_table,
[
bigquery.SchemaField("task_id", "STRING"),
bigquery.SchemaField("start_date", "DATE"),
bigquery.SchemaField("end_date", "DATE"),
bigquery.SchemaField("target", "STRING"),
bigquery.SchemaField("target_rows", "INT64"),
bigquery.SchemaField("target_bytes", "INT64"),
bigquery.SchemaField("source_bytes", "INT64"),
],
)
table.time_partitioning = bigquery.TimePartitioning()
client.create_table(table)
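            # Dry-run one query per distinct source to estimate how many bytes each
            # source scan will process.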
sources = list(set(source for task in tasks for source in task.sources))
source_bytes = {
source: job.total_bytes_processed
for source, job in zip(
sources,
pool.starmap(
client.query,
[
(
reformat(
f"""
SELECT
{source.field}
FROM
`{sql_table_id(source)}`
WHERE
{source_condition}
"""
),
bigquery.QueryJobConfig(dry_run=True),
)
for source in sources
],
chunksize=1,
),
)
}
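            # Stream one row of metadata per task into the task table, in batches.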
step = 10000 # max 10K rows per insert
for start in range(0, len(tasks), step):
end = start + step
BigQueryInsertError.raise_if_present(
errors=client.insert_rows_json(
args.task_table,
[
{
"task_id": get_task_id(task.table, task.partition_id),
"start_date": args.start_date.isoformat(),
"end_date": args.end_date.isoformat(),
"target": sql_table_id(task.table),
"target_rows": task.table.num_rows,
"target_bytes": task.table.num_bytes,
"source_bytes": sum(
map(source_bytes.get, task.sources)
),
}
for task in tasks[start:end]
],
)
)
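        # Run the deletion tasks on the client queue; pool.map preserves task order
        # in the results.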
results = pool.map(
client_q.with_client, (task.func for task in tasks), chunksize=1
)
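    # Aggregate bytes scanned (and rows deleted, outside dry runs) per target table.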
jobs_by_table = defaultdict(list)
for i, job in enumerate(results):
jobs_by_table[tasks[i].table].append(job)
bytes_processed = rows_deleted = 0
for table, jobs in jobs_by_table.items():
table_bytes_processed = sum(job.total_bytes_processed or 0 for job in jobs)
bytes_processed += table_bytes_processed
table_id = sql_table_id(table)
if args.dry_run:
logging.info(f"Would scan {table_bytes_processed} bytes from {table_id}")
else:
table_rows_deleted = sum(job.num_dml_affected_rows or 0 for job in jobs)
rows_deleted += table_rows_deleted
logging.info(
f"Scanned {table_bytes_processed} bytes and "
f"deleted {table_rows_deleted} rows from {table_id}"
)
if args.dry_run:
logging.info(f"Would scan {bytes_processed} in total")
else:
logging.info(
f"Scanned {bytes_processed} and deleted {rows_deleted} rows in total"
)