src/translation/dags/batch_sql_translation.py (376 lines of code) (raw):

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import csv
import datetime
import logging
import os
from pathlib import Path

from airflow import models
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import BranchPythonOperator, PythonOperator
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.sensors.python import PythonSensor
from airflow.utils.trigger_rule import TriggerRule
from google.api_core.client_info import ClientInfo
from google.cloud import bigquery, storage
from translation_utils import csv_utils
from translation_utils import translation_stats_utils as stats_utils
from translation_utils.bigquery_migration_v2 import (
    create_single_task_migration_workflow,
    get_migration_workflow_state,
)

from common_utils import custom_user_agent, storage_utils
from common_utils.operators.reporting_operator import ReportingOperator

PROJECT_ID = os.environ.get("GCP_PROJECT_ID")
TRANSLATION_REPORT_FILENAME = os.environ.get(
    "TRANSLATION_REPORT_FILENAME", "batch_translation_report.csv"
)
SKIP_VALIDATION_DIR_NAME = "skip_validation"
REPORT_FILENAME = "batch_translation_report.csv"
MAP_FILENAME = "consumed_name_map.json"
TRANSLATION_STATS_TABLE_ID = f"{PROJECT_ID}.dmt_logs.dmt_translation_results"
TRANSLATION_AGG_RESULTS_TABLE_ID = (
    f"{PROJECT_ID}.dmt_logs.dmt_translation_aggregated_results"
)
ERROR_FILTER_KEY = "translationErrorFilterRules"
TRANSLATION_CONFIG_KEY = "type"
TRANSLATION_CONFIG_DDL = "ddl"
TRANSLATION_CONFIG_SQL = "sql"
TRANSLATION_CONFIG_DML = "dml"
TRANSLATION_CONFIG_DEFAULT_VALUE = TRANSLATION_CONFIG_DDL
DDL_FOLDER = "ddl"
SCHEMA_DAG_ID = "schema_dag"
VALIDATION_GKE_TYPE = "gke"
VALIDATION_CRUN_TYPE = "cloudrun"
VALIDATION_DEFAULT_TYPE = VALIDATION_GKE_TYPE
VALIDATION_TYPE_TO_DAG_ID_MAPPING = {
    VALIDATION_CRUN_TYPE: "validation_crun_dag",
    VALIDATION_GKE_TYPE: "validation_dag",
}
VALIDATION_DAG_ID = "validation_dag"
VALIDATION_CRUN_DAG_ID = "validation_crun_dag"
DML_VALIDATION_DAG_ID = "dml_validation_dag"

# Jinja templates handed to downstream operators; rendered at run time from
# the XComs pushed by the tasks named in ``task_ids``.
_CONFIG_XCOM_TEMPLATE = (
    "{{ ti.xcom_pull(task_ids='create_translation_workflow', key='config') }}"
)
_FILE_NAMES_XCOM_TEMPLATE = "{{ ti.xcom_pull(task_ids='get_successfully_translated_files', key='file_names') }}"

gcs_util = storage_utils.StorageUtils()
default_dag_args = {"start_date": datetime.datetime(2022, 1, 1)}
storage_client = storage.Client(
    client_info=ClientInfo(user_agent=custom_user_agent.USER_AGENT)
)


def get_validation_dag_id(validation_mode):
    """Return the validation DAG id registered for ``validation_mode``.

    Unknown (or ``None``) modes fall back to the DAG id registered for
    ``VALIDATION_DEFAULT_TYPE``.
    """
    return VALIDATION_TYPE_TO_DAG_ID_MAPPING.get(
        validation_mode,
        VALIDATION_TYPE_TO_DAG_ID_MAPPING[VALIDATION_DEFAULT_TYPE],
    )


def _get_run_type(config):
    """Return the case-folded translation type stored in ``config``.

    Falls back to ``TRANSLATION_CONFIG_DEFAULT_VALUE`` when ``config`` has no
    ``TRANSLATION_CONFIG_KEY`` entry.
    """
    return config.get(
        TRANSLATION_CONFIG_KEY, TRANSLATION_CONFIG_DEFAULT_VALUE
    ).casefold()


def is_ddl_run(config):
    """
    Helper function to determine if the current run is a DDL run

    A config without a ``TRANSLATION_CONFIG_KEY`` entry counts as a DDL run;
    otherwise the value is compared case-insensitively.
    """
    return _get_run_type(config) == TRANSLATION_CONFIG_DDL


def _create_translation_workflow(ti, **kwargs) -> None:
    """Create the migration workflow and publish it through XCom.

    Reads the ``config`` entry of the triggering DAG run's conf (a Python
    literal string, parsed with ``ast.literal_eval``), creates a single-task
    migration workflow from its ``migrationTask`` section, and pushes both the
    creation response (key ``workflow_info``) and the parsed config
    (key ``config``).
    """
    translation_config = ast.literal_eval(kwargs["dag_run"].conf["config"])["config"]
    response = create_single_task_migration_workflow(
        PROJECT_ID, translation_config["migrationTask"]
    )
    logging.info(f"workflow info: {response}")
    ti.xcom_push(key="workflow_info", value=response)
    ti.xcom_push(key="config", value=translation_config)


def _poll_workflow_state(ti) -> bool:
    """Sensor callable: return True once the migration workflow is COMPLETED.

    On completion the final workflow response is pushed to XCom under
    ``workflow_info`` so later tasks can read it from this task.
    """
    [workflow_info] = ti.xcom_pull(
        key="workflow_info", task_ids=["create_translation_workflow"]
    )
    response = get_migration_workflow_state(workflow_info["name"])
    logging.info(f"retrieved workflow state. state: {response['state']}")
    is_completed = response["state"] == "COMPLETED"
    if is_completed:
        ti.xcom_push(key="workflow_info", value=response)
    return is_completed


def _get_failed_files_from_csv(ti) -> None:
    """Collect failed files and error records from the translation report.

    Downloads ``REPORT_FILENAME`` from the translation target folder and walks
    its rows. Rows matched by the optional ``ERROR_FILTER_KEY`` rules are
    recorded as filtered errors and do not mark their file as failed.

    Pushes to XCom:
        files: target-folder-relative object paths of files with errors.
        errors: one stats record per report row that named a file.
    """
    config = ti.xcom_pull(key="config", task_ids="create_translation_workflow")
    error_row_filter = csv_utils.rules_filter(config.get(ERROR_FILTER_KEY))

    translation_details = config["migrationTask"]["translationConfigDetails"]
    source_path = translation_details["gcsSourcePath"]
    target_path = translation_details["gcsTargetPath"]
    target_bucket, target_folder = gcs_util.parse_bucket_and_blob_from_path(target_path)

    gcs_hook = GCSHook()
    report_object_name = storage_utils.append_blob_name_to_path(
        target_folder, REPORT_FILENAME
    )
    report = gcs_hook.download_as_byte_array(target_bucket, report_object_name).decode(
        "utf-8"
    )
    csv_reader = csv.reader(report.splitlines())

    failed_files = set()
    errors = []
    headers = next(csv_reader)
    logging.info("scanning failed files:")
    for row in csv_reader:
        row_dict = csv_utils.row_to_dict(headers, row)

        # Prefer the full path (made relative to the source folder); fall
        # back to the bare file name when the path column is blank.
        if row_dict["FilePath"] not in (None, ""):
            file_path = os.path.relpath(row_dict["FilePath"], start=source_path)
        elif row_dict["FileName"] not in (None, ""):
            file_path = row_dict["FileName"]
        else:
            logging.info("File path and name is blank in CSV file")
            continue  # If name and path both are blank skip the iteration

        failed_file_path = f"{target_folder}/{file_path}"
        if not error_row_filter or not error_row_filter(row_dict):
            failed_files.add(failed_file_path)
            logging.info(f"{failed_file_path} contains errors")
            errors.append(stats_utils.csv_row_to_record(row_dict, None, False))
        else:
            logging.info(f"{failed_file_path} contains errors (filtered)")
            errors.append(stats_utils.csv_row_to_record(row_dict, None, True))

    ti.xcom_push(key="files", value=list(failed_files))
    ti.xcom_push(key="errors", value=errors)


def _get_all_translated_files(ti) -> None:
    """Push the object names of all translated files to XCom (key ``files``).

    Lists the translation target folder and skips folder placeholders (names
    ending in ``/``), the report file, the name-map file, and anything under
    a ``SKIP_VALIDATION_DIR_NAME`` directory.
    """
    translation_config = ti.xcom_pull(
        key="config", task_ids="create_translation_workflow"
    )
    target_path = translation_config["migrationTask"]["translationConfigDetails"][
        "gcsTargetPath"
    ]
    target_bucket, target_folder = gcs_util.parse_bucket_and_blob_from_path(target_path)

    excluded_filenames = (REPORT_FILENAME, MAP_FILENAME)
    translated_files = set()
    for blob in storage_client.list_blobs(target_bucket, prefix=target_folder):
        file_full_path = Path(blob.name)
        if (
            not blob.name.endswith("/")
            and file_full_path.name not in excluded_filenames
            and SKIP_VALIDATION_DIR_NAME not in file_full_path.parts
        ):
            translated_files.add(blob.name)

    ti.xcom_push(key="files", value=list(translated_files))


def _get_successfully_translated_files(ti) -> None:
    """Push the files that translated without (unfiltered) errors.

    Pushes to XCom:
        files: object names present in the target folder but not reported as
            failed.
        file_names: the same files relative to the target folder; for non-DDL
            runs, files directly inside ``DDL_FOLDER`` are left out.

    Raises:
        Exception: if no file was successfully translated.
    """
    [all_files, failed_files] = ti.xcom_pull(
        key="files",
        task_ids=["get_all_translated_files", "get_failed_files_from_csv"],
    )
    config = ti.xcom_pull(key="config", task_ids="create_translation_workflow")
    target_path = config["migrationTask"]["translationConfigDetails"]["gcsTargetPath"]
    _, target_folder = gcs_util.parse_bucket_and_blob_from_path(target_path)

    successfully_translated_files = set(all_files).difference(set(failed_files))
    if not successfully_translated_files:
        raise Exception("No files were successfully translated.")

    ti.xcom_push(key="files", value=list(successfully_translated_files))

    ddl_run = is_ddl_run(config)  # invariant for the whole run
    file_names = []
    for filepath in successfully_translated_files:
        relative_path = os.path.relpath(filepath, start=target_folder)
        folder, _ = os.path.split(relative_path)
        if ddl_run or folder != DDL_FOLDER:
            file_names.append(relative_path)
    ti.xcom_push(key="file_names", value=file_names)


def _download_files(ti):
    """Download the translated files into the Airflow data directory.

    Reads the ``file_names`` XCom of ``get_successfully_translated_files`` and
    copies each file from the translation target folder in GCS to
    ``/home/airflow/gcs/data/<file name>``, creating local folders as needed.
    Does nothing besides logging when the list is empty or missing.

    Arguments:
        ti: task instance used to pull the XCom values.
    """
    config = ti.xcom_pull(key="config", task_ids="create_translation_workflow")
    file_list = ti.xcom_pull(
        key="file_names", task_ids="get_successfully_translated_files"
    )
    translation_target_path = config["migrationTask"]["translationConfigDetails"][
        "gcsTargetPath"
    ]
    bucket, folder = gcs_util.parse_bucket_and_blob_from_path(translation_target_path)
    logging.info(file_list)

    # ``not file_list`` also covers a missing XCom (None), not just [].
    if not file_list:
        logging.info("SQL File list is empty and there are no SQL files to download")
        return

    gcs_hook = GCSHook()
    for filename in file_list:
        local_filename = f"/home/airflow/gcs/data/{filename}"
        object_name = f"{folder}/{filename}"

        # make sure local folder exists before downloading
        local_folder, _ = os.path.split(local_filename)
        os.makedirs(local_folder, exist_ok=True)

        # download file
        gcs_hook.download(
            bucket_name=bucket,
            object_name=object_name,
            filename=local_filename,
        )
    logging.info(os.listdir("/home/airflow/gcs/data"))


def _determine_next_dag(ti):
    """Branch callable: return the task id of the downstream DAG trigger.

    The run type is compared case-insensitively, matching ``is_ddl_run`` and
    ``_save_stats``; a config without a type is treated as a DDL run.

    Raises:
        ValueError: if the configured type is not ddl, sql or dml.
    """
    config = ti.xcom_pull(key="config", task_ids="create_translation_workflow")
    run_type = _get_run_type(config)

    if run_type == TRANSLATION_CONFIG_DDL:
        return "invoke_schema_dag"
    if run_type == TRANSLATION_CONFIG_SQL:
        validation_mode = config["validation_config"].get("validation_mode")
        if get_validation_dag_id(validation_mode) == VALIDATION_DAG_ID:
            return "invoke_validation_dag"
        return "invoke_validation_crun_dag"
    if run_type == TRANSLATION_CONFIG_DML:
        return "invoke_dml_validation_dag"
    raise ValueError(f"invalid value for translation field: {run_type}")


def _insert_rows(bq_client, table_id, rows):
    """Stream ``rows`` into ``table_id`` and log any reported insert errors."""
    insert_errors = bq_client.insert_rows_json(table_id, rows)
    if insert_errors:
        # Logged rather than raised so stats reporting stays best-effort.
        logging.error(f"failed to insert rows into {table_id}: {insert_errors}")


def _save_stats(ti):
    """Write per-file and aggregated translation results to BigQuery.

    Per-file records (an "OK" record for each successfully translated file
    plus the error records from the report) go to
    ``TRANSLATION_STATS_TABLE_ID``; one summary row goes to
    ``TRANSLATION_AGG_RESULTS_TABLE_ID``.
    """
    translation_config = ti.xcom_pull(
        key="config", task_ids="create_translation_workflow"
    )
    all_translated_files = ti.xcom_pull(
        key="files", task_ids="get_all_translated_files"
    )
    successfully_translated_files = ti.xcom_pull(
        key="files", task_ids="get_successfully_translated_files"
    )
    failed_files = ti.xcom_pull(key="files", task_ids="get_failed_files_from_csv")
    translation_errors = ti.xcom_pull(
        key="errors", task_ids="get_failed_files_from_csv"
    )
    workflow_info = ti.xcom_pull(key="workflow_info", task_ids="poll_workflow_state")
    logging.info(f"workflow info: {workflow_info}")

    stats = [
        stats_utils.new_record(file.split("/")[-1], None, "OK")
        for file in successfully_translated_files
    ]
    stats += translation_errors

    # Run-level fields are identical for every record.
    run_type = _get_run_type(translation_config)
    for record in stats:
        record["unique_id"] = translation_config["unique_id"]
        record["create_time"] = workflow_info["createTime"]
        record["name"] = workflow_info["name"]
        record["type"] = run_type

    bq_client = bigquery.Client(
        client_info=ClientInfo(user_agent=custom_user_agent.USER_AGENT)
    )

    if not stats:
        logging.info("Translation stats are empty. Please check translation csv file. ")
    else:
        _insert_rows(bq_client, TRANSLATION_STATS_TABLE_ID, stats)

    translation_details = translation_config["migrationTask"][
        "translationConfigDetails"
    ]
    target_path = translation_details["gcsTargetPath"]
    # NOTE(review): the query string is the bare project id; Cloud Console
    # links usually look like "?project=<id>" - confirm before changing.
    translation_summary_csv_path = (
        "https://console.cloud.google.com/storage/browser/_details/"
        f"{target_path.split('//')[-1]}/{REPORT_FILENAME}"
        f";tab=live_object?{PROJECT_ID}"
    )

    summary_row = {
        "unique_id": translation_config["unique_id"],
        "total_files": len(all_translated_files),
        "successful_files": len(successfully_translated_files),
        "failed_files": len(failed_files),
        "input_files_path": translation_details["gcsSourcePath"],
        "output_files_path": target_path,
        "translation_summary_csv_path": translation_summary_csv_path,
    }
    _insert_rows(bq_client, TRANSLATION_AGG_RESULTS_TABLE_ID, [summary_row])


with models.DAG(
    "batch_sql_translation",
    schedule=None,
    default_args=default_dag_args,
    render_template_as_native_obj=True,
) as dag:
    # tasks
    create_translation_workflow = PythonOperator(
        task_id="create_translation_workflow",
        python_callable=_create_translation_workflow,
        dag=dag,
    )
    get_failed_files_from_csv = PythonOperator(
        task_id="get_failed_files_from_csv",
        python_callable=_get_failed_files_from_csv,
        dag=dag,
    )
    get_all_translated_files = PythonOperator(
        task_id="get_all_translated_files",
        python_callable=_get_all_translated_files,
        dag=dag,
    )
    get_successfully_translated_files = PythonOperator(
        task_id="get_successfully_translated_files",
        python_callable=_get_successfully_translated_files,
        dag=dag,
    )
    save_stats = PythonOperator(
        task_id="save_stats", python_callable=_save_stats, dag=dag
    )
    download_files = PythonOperator(
        task_id="download_files", python_callable=_download_files, dag=dag
    )
    determine_next_dag = BranchPythonOperator(
        task_id="determine_next_dag",
        python_callable=_determine_next_dag,
        dag=dag,
    )
    # Each trigger gets its own conf dict so no operator shares mutable state.
    invoke_schema_dag = TriggerDagRunOperator(
        task_id="invoke_schema_dag",
        trigger_dag_id=SCHEMA_DAG_ID,
        conf={
            "config": _CONFIG_XCOM_TEMPLATE,
            "files": _FILE_NAMES_XCOM_TEMPLATE,
        },
        dag=dag,
    )
    invoke_validation_dag = TriggerDagRunOperator(
        task_id="invoke_validation_dag",
        trigger_dag_id=VALIDATION_DAG_ID,
        conf={
            "config": _CONFIG_XCOM_TEMPLATE,
            "files": _FILE_NAMES_XCOM_TEMPLATE,
        },
        dag=dag,
    )
    invoke_validation_crun_dag = TriggerDagRunOperator(
        task_id="invoke_validation_crun_dag",
        trigger_dag_id=VALIDATION_CRUN_DAG_ID,
        conf={
            "config": _CONFIG_XCOM_TEMPLATE,
            "files": _FILE_NAMES_XCOM_TEMPLATE,
        },
        dag=dag,
    )
    invoke_dml_validation_dag = TriggerDagRunOperator(
        task_id="invoke_dml_validation_dag",
        trigger_dag_id=DML_VALIDATION_DAG_ID,
        conf={
            "config": _CONFIG_XCOM_TEMPLATE,
            "files": _FILE_NAMES_XCOM_TEMPLATE,
        },
        dag=dag,
    )
    end_task = EmptyOperator(task_id="end", dag=dag)

    # sensors
    poll_workflow_state = PythonSensor(
        task_id="poll_workflow_state",
        python_callable=_poll_workflow_state,
        dag=dag,
        poke_interval=15,
    )
    dag_report = ReportingOperator(
        task_id="dag_report",
        trigger_rule=TriggerRule.ALL_DONE,  # Ensures this task runs even if upstream fails
        configuration=_CONFIG_XCOM_TEMPLATE,
        dag=dag,
    )

    # dependencies
    create_translation_workflow >> poll_workflow_state
    (
        poll_workflow_state
        >> [get_failed_files_from_csv, get_all_translated_files]
        >> get_successfully_translated_files
    )
    get_successfully_translated_files >> [download_files, save_stats]
    (
        download_files
        >> determine_next_dag
        >> [
            invoke_schema_dag,
            invoke_validation_dag,
            invoke_validation_crun_dag,
            invoke_dml_validation_dag,
        ]
        >> end_task
        >> dag_report
    )