# composer-dags/demo_pipeline_composer.py
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------------
# Load The Dependencies
# --------------------------------------------------------------------------------
import json
import re
import uuid
from datetime import datetime, timedelta

from airflow import models
from airflow.models.variable import Variable
from airflow.operators import empty
from airflow.operators.python import PythonOperator
from airflow.providers.google.cloud.operators.dataflow import DataflowStartFlexTemplateOperator
from airflow.providers.google.cloud.operators.dataform import (
    DataformCreateCompilationResultOperator,
    DataformCreateWorkflowInvocationOperator,
)
from airflow.providers.google.cloud.operators.dataproc import (
    DataprocCreateBatchOperator,
    DataprocGetBatchOperator,
)
from airflow.providers.google.cloud.sensors.dataform import DataformWorkflowInvocationStateSensor
from airflow.providers.google.cloud.sensors.dataproc import DataprocBatchSensor
from airflow.utils.task_group import TaskGroup
from google.cloud import storage
from google.cloud.dataform_v1beta1 import WorkflowInvocation
# --------------------------------------------------------------------------------
# Read variables from GCS parameters file for the job
# --------------------------------------------------------------------------------
storage_client = storage.Client()
jobs_bucket = Variable.get("DATA_TRANSFORMATION_GCS_BUCKET")
batch_id = f"aef-{str(uuid.uuid4())}"
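# Dataproc Serverless batch IDs may only contain lowercase letters, digits and hyphens,
# so the "aef-" prefix plus a hex UUID both satisfies the naming rules and keeps the ID unique.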
def extract_job_params(job_name, function_name, encoding='utf-8'):
    """Extracts parameters from a JSON job definition file stored in GCS.

    Args:
        job_name: Name of the job; the file is read from
            gs://<jobs_bucket>/<function_name>/<job_name>.json.
        function_name: Name of the executor function whose folder contains the file.
        encoding: Text encoding of the JSON file.

    Returns:
        A dictionary containing the extracted parameters.
    """
    json_file_path = f'gs://{jobs_bucket}/{function_name}/{job_name}.json'
    parts = json_file_path.replace("gs://", "").split("/")
    bucket_name = parts[0]
    object_name = "/".join(parts[1:])
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(object_name)
    json_data = blob.download_as_bytes()
    params = json.loads(json_data.decode(encoding))
    return params
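# Illustrative only: the shape of each parameters file is inferred from the lookups made
# later in this DAG; the values below are placeholders, not real project settings.
#
#   gs://<DATA_TRANSFORMATION_GCS_BUCKET>/dataproc-serverless-job-executor/sample_serverless_spark_mainframe_ingestion.json
#   {
#     "jar_file_location": "gs://example-bucket/jars/ingestion.jar",
#     "spark_app_main_class": "com.example.MainframeIngestion",
#     "spark_args": ["--run-date", "2024-01-01"],
#     "spark_app_properties": {"spark.executor.instances": "2"},
#     "dataproc_serverless_runtime_version": "2.2",
#     "dataproc_service_account": "sa@example-project.iam.gserviceaccount.com",
#     "dataproc_serverless_project_id": "example-project",
#     "dataproc_serverless_region": "us-central1",
#     "subnetwork": "regions/us-central1/subnetworks/default"
#   }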
# --------------------------------------------------------------------------------
# Set default arguments
# --------------------------------------------------------------------------------
# If you are running Airflow in more than one time zone
# see https://airflow.apache.org/docs/apache-airflow/stable/timezone.html
# for best practices
yesterday = datetime.now() - timedelta(days=1)
default_args = {
'owner': 'airflow',
'start_date': yesterday,
'depends_on_past': False,
'email': [''],
'email_on_failure': False,
'email_on_retry': False,
'retries': 0,
'retry_delay': timedelta(minutes=5)
}
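# Each job's parameter file is flattened into default_args with the job name as a key
# prefix, e.g. default_args['run_dataform_tag' + 'dataform_location']; this is how the
# operators below look up their settings.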
sample_serverless_spark_mainframe_ingestion = extract_job_params(
    'sample_serverless_spark_mainframe_ingestion', 'dataproc-serverless-job-executor')
for key, value in sample_serverless_spark_mainframe_ingestion.items():
    default_args['sample_serverless_spark_mainframe_ingestion' + key] = value

sample_jdbc_dataflow_ingestion = extract_job_params(
    'sample_jdbc_dataflow_ingestion', 'dataflow-flextemplate-job-executor')
for key, value in sample_jdbc_dataflow_ingestion.items():
    default_args['sample_jdbc_dataflow_ingestion' + key] = value

run_dataform_tag = extract_job_params('run_dataform_tag', 'dataform-tag-executor')
for key, value in run_dataform_tag.items():
    default_args['run_dataform_tag' + key] = value
start_date_str = yesterday.strftime('%Y-%m-%d')
end_date_str = datetime.today().strftime('%Y-%m-%d')
# --------------------------------------------------------------------------------
# Main DAG
# --------------------------------------------------------------------------------
with models.DAG(
'demo_pipeline_composer',
default_args=default_args,
params={
"start_date_str": start_date_str,
"end_date_str": end_date_str
},
catchup=False,
schedule_interval=None) as dag:
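    # schedule_interval=None means this DAG only runs when triggered manually or via the
    # API; catchup=False prevents Airflow from back-filling past runs.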
start = empty.EmptyOperator(
task_id='start',
trigger_rule='all_success'
)
# Start level group definition
with TaskGroup(group_id="Level_1") as tg_Level_1:
# Start thread group definition
with TaskGroup(group_id="Level_1_Thread_1") as tg_level_1_Thread_1:
with TaskGroup(group_id="sample_serverless_spark_mainframe_ingestion") as sample_serverless_spark_mainframe_ingestion:
def push_batch_id_to_xcom(**context):
context['task_instance'].xcom_push(key='batch-id', value=batch_id)
                push_batch_id_to_xcom_sample_serverless_spark_mainframe_ingestion = PythonOperator(
                    task_id='batch_id',
                    python_callable=push_batch_id_to_xcom
                )
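                # The module-level batch_id is re-generated every time this file is parsed, so
                # the create/sensor/get tasks below never read it directly: they pull the value
                # this task pushed to XCom, guaranteeing all three refer to the same batch.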
# create batch
create_batch_for_sample_serverless_spark_mainframe_ingestion = DataprocCreateBatchOperator(
task_id="create_batch_for_sample_serverless_spark_mainframe_ingestion",
batch={
"spark_batch": {
"jar_file_uris": [default_args['sample_serverless_spark_mainframe_ingestion' + 'jar_file_location']],
"main_class": default_args['sample_serverless_spark_mainframe_ingestion' + 'spark_app_main_class'],
"args": default_args['sample_serverless_spark_mainframe_ingestion' + 'spark_args'],
},
"runtime_config": {
"version": default_args['sample_serverless_spark_mainframe_ingestion' + 'dataproc_serverless_runtime_version'],
"properties": default_args['sample_serverless_spark_mainframe_ingestion' + 'spark_app_properties'],
},
"environment_config": {
"execution_config": {
"service_account": default_args['sample_serverless_spark_mainframe_ingestion' + 'dataproc_service_account'],
"subnetwork_uri": f"projects/{default_args['sample_serverless_spark_mainframe_ingestion' + 'dataproc_serverless_project_id']}/{default_args['sample_serverless_spark_mainframe_ingestion' + 'subnetwork']}"
}
}
},
batch_id="{{ task_instance.xcom_pull(task_ids='Level_1.Level_1_Thread_1.sample_serverless_spark_mainframe_ingestion.batch_id', key='batch-id') }}",
project_id=default_args['sample_serverless_spark_mainframe_ingestion' + 'dataproc_serverless_project_id'],
region=default_args['sample_serverless_spark_mainframe_ingestion' + 'dataproc_serverless_region'],
deferrable=True
)
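                # deferrable=True hands the wait for batch completion to the Airflow triggerer
                # instead of occupying a worker slot while the Spark batch runs.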
                wait_for_batch_completion_for_sample_serverless_spark_mainframe_ingestion = DataprocBatchSensor(
                    task_id='wait_for_batch_completion_for_sample_serverless_spark_mainframe_ingestion',
                    batch_id="{{ task_instance.xcom_pull(task_ids='Level_1.Level_1_Thread_1.sample_serverless_spark_mainframe_ingestion.batch_id', key='batch-id') }}",
                    project_id=default_args['sample_serverless_spark_mainframe_ingestion' + 'dataproc_serverless_project_id'],
                    region=default_args['sample_serverless_spark_mainframe_ingestion' + 'dataproc_serverless_region'],
                    poke_interval=400,
                    timeout=3600,
                    soft_fail=True
                )
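                # With soft_fail=True, a sensor timeout (3600s here, polled every 400s) marks
                # the task as skipped instead of failed.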
                get_batch_for_sample_serverless_spark_mainframe_ingestion = DataprocGetBatchOperator(
                    task_id="get_batch_for_sample_serverless_spark_mainframe_ingestion",
                    batch_id="{{ task_instance.xcom_pull(task_ids='Level_1.Level_1_Thread_1.sample_serverless_spark_mainframe_ingestion.batch_id', key='batch-id') }}",
                    project_id=default_args['sample_serverless_spark_mainframe_ingestion' + 'dataproc_serverless_project_id'],
                    region=default_args['sample_serverless_spark_mainframe_ingestion' + 'dataproc_serverless_region']
                )
push_batch_id_to_xcom_sample_serverless_spark_mainframe_ingestion >> create_batch_for_sample_serverless_spark_mainframe_ingestion >> wait_for_batch_completion_for_sample_serverless_spark_mainframe_ingestion >> get_batch_for_sample_serverless_spark_mainframe_ingestion
        # End thread group definition
        # Start thread group definition
with TaskGroup(group_id="Level_1_Thread_2") as tg_level_1_Thread_2:
with TaskGroup(group_id="sample_jdbc_dataflow_ingestion") as sample_jdbc_dataflow_ingestion:
                dataflow_job_name = re.sub(r"[^a-z0-9]", "", "sample_jdbc_dataflow_ingestion")
                # Dataflow job names must start with a letter, so drop any leading digits.
                dataflow_job_name = re.sub(r"^\d+", "", dataflow_job_name)
                gcs_path = "gs://dataflow-templates-{region}/{version}/flex/{template}".format(
                    region=default_args['sample_jdbc_dataflow_ingestion' + 'dataflow_location'],
                    version=default_args['sample_jdbc_dataflow_ingestion' + 'dataflow_template_version'],
                    template=default_args['sample_jdbc_dataflow_ingestion' + 'dataflow_template_name'])
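                # Resolves to a Google-provided Flex Template spec path, e.g. (illustrative)
                # gs://dataflow-templates-us-central1/latest/flex/Jdbc_to_BigQuery_Flex.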
body = {
"launchParameter": {
"jobName": dataflow_job_name,
"parameters": default_args['sample_jdbc_dataflow_ingestion'+'dataflow_job_params'],
"containerSpecGcsPath": gcs_path,
"environment": {
"tempLocation": "gs://{bucket}/dataflow/temp".format(bucket=default_args['sample_jdbc_dataflow_ingestion' + 'dataflow_temp_bucket']),
"maxWorkers": str(default_args['sample_jdbc_dataflow_ingestion' + 'dataflow_max_workers']),
"network": str(default_args['sample_jdbc_dataflow_ingestion' + 'network']),
"subnetwork": str(default_args['sample_jdbc_dataflow_ingestion' + 'subnetwork'])}
}
}
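                # The body follows the Dataflow flexTemplates.launch request layout: jobName,
                # parameters, containerSpecGcsPath and environment are passed through to the
                # service unchanged by the operator below.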
dataflow_job_sample_jdbc_dataflow_ingestion = DataflowStartFlexTemplateOperator(
task_id="dataflow_flex_template_sample_jdbc_dataflow_ingestion",
location=default_args['sample_jdbc_dataflow_ingestion'+'dataflow_location'],
body=body
)
        # End thread group definition
# End level group definition
# Start level group definition
with TaskGroup(group_id="Level_2") as tg_Level_2:
# Start thread group definition
with TaskGroup(group_id="Level_2_Thread_3") as tg_level_2_Thread_3:
with TaskGroup(group_id="run_dataform_tag") as run_dataform_tag:
# compilation
create_compilation_result_for_run_dataform_tag = DataformCreateCompilationResultOperator(
project_id=default_args['run_dataform_tag'+'dataform_project_id'],
region=default_args['run_dataform_tag'+'dataform_location'],
repository_id=default_args['run_dataform_tag'+'repository_name'],
task_id="compilation_task_run_dataform_tag",
compilation_result={
"git_commitish": default_args['run_dataform_tag'+'branch'],
"code_compilation_config": {
"vars": {
"start_date": "{{ params.start_date_str }}",
"end_date": "{{ params.end_date_str }}"
}
}
}
)
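                # The templated start/end dates are handed to Dataform as compilation variables;
                # SQLX in the repository can reference them (typically via
                # dataform.projectConfig.vars) as long as the project declares matching variables.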
# workflow invocation in dataform
create_workflow_run_dataform_tag_invocation = DataformCreateWorkflowInvocationOperator(
project_id=default_args['run_dataform_tag' + 'dataform_project_id'],
region=default_args['run_dataform_tag' + 'dataform_location'],
repository_id=default_args['run_dataform_tag' + 'repository_name'],
task_id='workflow_inv_run_dataform_tag',
asynchronous=True,
                    workflow_invocation={
                        "compilation_result": "{{ task_instance.xcom_pull('Level_2.Level_2_Thread_3.run_dataform_tag.compilation_task_run_dataform_tag')['name'] }}",
                        "invocation_config": {
                            "included_tags": default_args['run_dataform_tag' + 'tags'],
                            "transitive_dependencies_included": True
                        }
                    },
trigger_rule='all_success'
)
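                # asynchronous=True returns as soon as the invocation is created; the sensor
                # below does the actual waiting and fails if the invocation ends FAILED or
                # CANCELLED.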
is_workflow_run_dataform_tag_invocation_done = DataformWorkflowInvocationStateSensor(
project_id=default_args['run_dataform_tag' + 'dataform_project_id'],
region=default_args['run_dataform_tag' + 'dataform_location'],
repository_id=default_args['run_dataform_tag' + 'repository_name'],
task_id="is_workflow_run_dataform_tag_invocation_done",
workflow_invocation_id=("{{ task_instance.xcom_pull('Level_2.Level_2_Thread_3.run_dataform_tag.workflow_inv_run_dataform_tag')['name'].split('/')[-1] }}"),
expected_statuses={WorkflowInvocation.State.SUCCEEDED},
failure_statuses={WorkflowInvocation.State.FAILED, WorkflowInvocation.State.CANCELLED},
)
create_compilation_result_for_run_dataform_tag >> create_workflow_run_dataform_tag_invocation >> is_workflow_run_dataform_tag_invocation_done
        # End thread group definition
# End level group definition
end = empty.EmptyOperator(
task_id='end',
trigger_rule='all_success'
)
start >> tg_Level_1 >> tg_Level_2 >> end