projects/database-archival/src/database_archival/dag/tasks/data_pruning.py

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides tasks to perform data pruning on the database."""

import enum
import time

from airflow.decorators import task
from airflow.decorators import task_group
from airflow.models import expandinput
from airflow.operators import python as airflow_operators
from airflow.providers.google.cloud.operators import bigquery
from airflow.providers.google.common.utils import id_token_credentials
from google.auth.transport import requests

from database_archival.common.models import database
from database_archival.dag.tasks import data_pruning_preparation
from database_archival.dag.utils import task_namer
from database_archival.dag import config

from typing import Iterable, Optional


TASK_GROUP_NAME = 'delete_data_from_database'
TASK_NAME_GET_BATCHES = 'get_batches_for_loop'
TASK_GROUP_NAME_DELETE_BATCH_AND_UPDATE_PROGRESS = (
    'delete_and_update_data_batch'
)
TASK_NAME_DELETE_BATCH = 'delete_data_from_database'
TASK_GROUP_UPDATE_PROGRESS_BATCH = 'update_pruned_rows_status'
TASK_NAME_GET_CONFIG_FOR_STATUS_UPDATE = 'get_update_prune_status_job_config'
TASK_NAME_UPDATE_STATUS = 'update_pruned_rows_status_job'


class _TaskPriority(enum.Enum):
    """Defines Airflow Task priorities.

    The higher the value, the higher the priority. The actual values are
    arbitrary; only their relative order matters.
    """

    DEFAULT = 1
    LOW = 10
    MEDIUM = 100
    HIGH = 1_000


@task_group(group_id=TASK_GROUP_NAME)
def delete_data_from_database(
    *,
    bigquery_location: str,
    database_type: database.DatabaseType,
    database_instance_name: Optional[str] = None,
    database_host: Optional[str] = None,
    database_name: str,
    database_table_name: str,
    table_primary_key_columns: Iterable[str],
    database_username: str,
    database_password: Optional[str] = None,
    database_password_secret: Optional[str] = None,
):
    """Processes batches and requests deletion of data from the database.

    Args:
        bigquery_location: location where BigQuery jobs will be run.
        database_type: type of database where the data is hosted.
        database_instance_name: instance name of the database.
            AlloyDB Format: projects/<project_id>/locations/<region_id>/
                clusters/<cluster_id>/instances/<instance_id>
            Cloud SQL Format: <project_id>:<region_id>:<instance_name>
        database_host: hostname or IP where the database is located.
        database_name: name of the database where the table is located.
        database_table_name: name of the table that will be pruned. Used
            also for identifying tasks.
        table_primary_key_columns: list of primary keys for the table.
        database_username: username to access the database.
        database_password: password for the username to access the database.
        database_password_secret: password for the username to access the
            database, stored in Secret Manager. Format:
            "projects/<projectId>/secrets/<secretId>/versions/<versionId>".
    """
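    # The batch list is produced upstream by the data_pruning_preparation
    # task group and read back here through XCom. Each batch number then
    # fans out into one mapped delete_and_update_data_batch task group via
    # Airflow dynamic task mapping (partial() / expand()).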
""" @task(task_id=TASK_NAME_GET_BATCHES) def get_batches_for_loop(): context = airflow_operators.get_current_context() batches = context['ti'].xcom_pull( task_ids=task_namer.get_task_name( database_table_name=database_table_name, task_group_name=data_pruning_preparation.TASK_GROUP_NAME, task_name=data_pruning_preparation.TASK_NAME_GET_BATCH_LIST, ), key='return_value', ) return batches delete_and_update_data_batch.partial( database_table_name=database_table_name, bigquery_location=bigquery_location, database_type=database_type, database_instance_name=database_instance_name, database_host=database_host, database_name=database_name, database_username=database_username, database_password=database_password, database_password_secret=database_password_secret, table_primary_key_columns=table_primary_key_columns, ).expand(batch_number_expand=get_batches_for_loop()) @task_group(group_id=TASK_GROUP_NAME_DELETE_BATCH_AND_UPDATE_PROGRESS) def delete_and_update_data_batch( *, bigquery_location: str, database_type: database.DatabaseType, database_instance_name: Optional[str] = None, database_host: Optional[str] = None, database_name: str, database_table_name: str, table_primary_key_columns: Iterable[str], database_username: str, database_password: Optional[str] = None, database_password_secret: Optional[str] = None, batch_number_expand: expandinput.MappedArgument, ): """Requests to deletes a batch of data from database and updates status. Args: bigquery_location: location where BigQuery jobs will be run. database_type: type of database where the data is hosted. database_instance_name: intance name of the database. AlloyDB Format: projects/<project_id>/locations/<region_id>/ clusters/<cluster_id>/instances/<instance_id> Cloud SQL Format: <project_id>:<region_id>:<instance_name> database_host: hostname or IP where the database is located. database_name: name of the database where the table is located. database_table_name: name of the table that will be pruned. Used also for identifying tasks. table_primary_key_columns: list of primary keys for the table. database_username: username to access the database. database_password: password for the username to access the database. database_password_secret: password for the username to access the database, stored in Secrets Manager. Format: "projects/<projectId>/secrets/<secretId>/versions/<versionId>". batch_number_expand: number of the batch to delete. """ @task( task_id=TASK_NAME_DELETE_BATCH, # Limit the pruning tasks to run only one batch at a time per table to # avoid overloading the database with delete requests in parallel for # the same table. max_active_tis_per_dagrun=1, pre_execute=time.sleep(config.TIME_BETWEEN_DELETE_BATCHES_IN_SECONDS), ) def call_cloud_function_to_prune_data(batch_number): """Calls Cloud Function to request deleting batch from the database. Args: batch_number: batch number that will be pruned. 
""" context = airflow_operators.get_current_context() dag_run_id = context['run_id'] dag_date = context['ds_nodash'] snapshot_progress_table_name = context['ti'].xcom_pull( task_ids=task_namer.get_task_name( database_table_name=database_table_name, task_group_name=data_pruning_preparation.TASK_GROUP_NAME, task_name=data_pruning_preparation.TASK_NAME_GET_CONFIG, ), key=data_pruning_preparation.VAR_NAME_PRUNE_PROGRESS_TABLE, ) credentials = id_token_credentials.get_default_id_token_credentials( config.CLOUD_FUNCTION_URL_DATA_DELETION ) response = requests.AuthorizedSession(credentials).post( url=config.CLOUD_FUNCTION_URL_DATA_DELETION, json={ 'bigquery_location': bigquery_location, 'snapshot_progress_table_name': snapshot_progress_table_name, 'snapshot_date': dag_date, 'snapshot_run_id': dag_run_id, 'snapshot_batch': batch_number, 'database_type': database_type.value, # enum to string value. 'database_instance_name': database_instance_name, 'database_host': database_host, 'database_name': database_name, 'database_username': database_username, 'database_password': database_password, 'database_password_secret': database_password_secret, 'database_table_name': database_table_name, 'table_primary_key_columns': table_primary_key_columns, }, timeout=config.DATABASE_TIMEOUT, ) if response.status_code != 200: raise RuntimeError( 'Cloud Function call to prune data failed with status code ' f'{response.status_code}. Response text: {response.text}. ' 'Check Cloud Function logs for more details.' ) return response.json() update_batch_status_in_bigquery = update_pruned_rows_status( database_table_name=database_table_name, bigquery_location=bigquery_location, batch_number_expand=batch_number_expand, ) # pylint: disable-next=pointless-statement, expression-not-assigned ( call_cloud_function_to_prune_data(batch_number=batch_number_expand) >> update_batch_status_in_bigquery ) @task_group(group_id=TASK_GROUP_UPDATE_PROGRESS_BATCH) def update_pruned_rows_status( *, database_table_name: str, bigquery_location: str, batch_number_expand: expandinput.MappedArgument, ): """Updates the batch pruning status in BigQuery. Args: database_table_name: name of the table that was pruned. Used for identifying tasks. bigquery_location: location where BigQuery jobs will be run. batch_number: number of the batch that was pruned. """ @task( task_id=TASK_NAME_GET_CONFIG_FOR_STATUS_UPDATE, priority_weight=_TaskPriority.MEDIUM.value, # Prioritize over delete. ) def get_update_prune_status_job_config(batch_number: int): """Creates the job config to update pruned rows status. Args: batch_number: batch number for which to create job config. 
""" context = airflow_operators.get_current_context() dag_date = context['ds_nodash'] dag_run_id = context['run_id'] snapshot_progress_table_name = context['ti'].xcom_pull( task_ids=task_namer.get_task_name( database_table_name=database_table_name, task_group_name=data_pruning_preparation.TASK_GROUP_NAME, task_name=data_pruning_preparation.TASK_NAME_GET_CONFIG, ), key=data_pruning_preparation.VAR_NAME_PRUNE_PROGRESS_TABLE, ) set_fields = [ f'{database.FIELD_NAME_PRUNE_STATUS} = TRUE', ] parsed_dag_date = f'PARSE_DATE("%Y%m%d", "{dag_date}")' condition_fields = [ f'{database.FIELD_NAME_SNAPSHOT_RUN_ID} = "{dag_run_id}"', f'{database.FIELD_NAME_SNAPSHOT_DATE} = {parsed_dag_date}', f'{database.FIELD_NAME_PRUNE_BATCH} = {batch_number}', ] return { 'query': { 'query': ( f'UPDATE {snapshot_progress_table_name} ' f"SET {', '.join(set_fields)} " f"WHERE {' AND '.join(condition_fields)}" ), 'use_legacy_sql': False, }, } update_pruned_rows_status_job = bigquery.BigQueryInsertJobOperator( task_id=TASK_NAME_UPDATE_STATUS, priority_weight=_TaskPriority.HIGH.value, # Prioritize over config. configuration=get_update_prune_status_job_config( batch_number=batch_number_expand, ), location=bigquery_location, ) # pylint: disable-next=pointless-statement, expression-not-assigned update_pruned_rows_status_job