# cloud-composer/dags/sample-rideshare-run-data-quality.py
####################################################################################
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
####################################################################################
# Author: Adam Paternostro
# Summary: Runs a Dataplex data quality job against the rideshare_lakehouse_curated.bigquery_rideshare_trip table
#          Polls the DQ job until it completes
#          Reads the data quality results (from BigQuery) and attaches a table-level tag template to
#          rideshare_lakehouse_curated.bigquery_rideshare_trip in Data Catalog
#          Reads the data quality results (from BigQuery) and attaches a column-level tag template to the same table
#          Users can see the table-level results by searching for the bigquery_rideshare_trip table (Curated zone) in Data Catalog,
#          and the column-level results by opening the Schema view of that entry
#
# References:
# https://github.com/GoogleCloudPlatform/cloud-data-quality
# https://cloud.google.com/dataplex/docs/check-data-quality
# https://github.com/GoogleCloudPlatform/cloud-data-quality/blob/main/scripts/dataproc-workflow-composer/clouddq_composer_dataplex_task_job.py
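#
# Example trigger (illustrative only; the Composer environment name and location below are
# placeholders, not values defined by this project):
#   gcloud composer environments run YOUR_COMPOSER_ENV --location YOUR_COMPOSER_REGION \
#       dags trigger -- sample-rideshare-run-data-quality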
from datetime import datetime, timedelta
import requests
import sys
import os
import logging
import json
import time
import airflow
from airflow.operators import bash_operator
from airflow.utils import trigger_rule
from airflow.operators.python_operator import PythonOperator
from airflow.providers.google.cloud.operators.bigquery import BigQueryCreateEmptyDatasetOperator
from airflow.providers.google.cloud.operators.dataplex import DataplexCreateTaskOperator
from google.cloud import bigquery
from google.protobuf.duration_pb2 import Duration
import google.auth
import google.auth.transport.requests
from google.cloud import datacatalog_v1
default_args = {
'owner': 'airflow',
'depends_on_past': False,
'email': None,
'email_on_failure': False,
'email_on_retry': False,
'retries': 0,
'retry_delay': timedelta(minutes=5),
'dagrun_timeout' : timedelta(minutes=60),
}
project_id = os.environ['ENV_PROJECT_ID']
rideshare_dataset_id = "rideshare_lakehouse_curated"
code_bucket_name = os.environ['ENV_CODE_BUCKET']
yaml_path = "gs://" + code_bucket_name + "/dataplex/data-quality/dataplex_data_quality_rideshare.yaml"
bigquery_region = os.environ['ENV_BIGQUERY_REGION']
thelook_dataset_id = os.environ['ENV_THELOOK_DATASET_ID']
vpc_subnet_name = os.environ['ENV_DATAPROC_SERVERLESS_SUBNET_NAME']
dataplex_region = os.environ['ENV_DATAPLEX_REGION']
service_account_to_run_dataplex = "dataproc-service-account@" + project_id + ".iam.gserviceaccount.com"
random_extension = os.environ['ENV_RANDOM_EXTENSION']
rideshare_dataplex_lake_name = "rideshare-lakehouse-" + random_extension
data_quality_dataset_id = "dataplex_data_quality"
data_quality_table_name = "data_quality_results"
DATAPLEX_PUBLIC_GCS_BUCKET_NAME = f"dataplex-clouddq-artifacts-{dataplex_region}"
CLOUDDQ_EXECUTABLE_FILE_PATH = f"gs://{DATAPLEX_PUBLIC_GCS_BUCKET_NAME}/clouddq-executable.zip"
CLOUDDQ_EXECUTABLE_HASHSUM_FILE_PATH = f"gs://{DATAPLEX_PUBLIC_GCS_BUCKET_NAME}/clouddq-executable.zip.hashsum"
FULL_TARGET_TABLE_NAME = f"{project_id}.{data_quality_dataset_id}.{data_quality_table_name}"
TRIGGER_SPEC_TYPE = "ON_DEMAND"
spark_python_script_file = f"gs://{DATAPLEX_PUBLIC_GCS_BUCKET_NAME}/clouddq_pyspark_driver.py"
# NOTE: The region value is case sensitive for some reason
bigquery_region = bigquery_region.upper()
# https://cloud.google.com/dataplex/docs/reference/rpc/google.cloud.dataplex.v1
# https://cloud.google.com/dataplex/docs/reference/rpc/google.cloud.dataplex.v1#google.cloud.dataplex.v1.Task.InfrastructureSpec.VpcNetwork
data_quality_config = {
"spark": {
"python_script_file": spark_python_script_file,
"file_uris": [CLOUDDQ_EXECUTABLE_FILE_PATH,
CLOUDDQ_EXECUTABLE_HASHSUM_FILE_PATH,
yaml_path
],
"infrastructure_spec" : {
"vpc_network" : {
"sub_network" : vpc_subnet_name
}
},
},
"execution_spec": {
"service_account": service_account_to_run_dataplex,
"max_job_execution_lifetime" : Duration(seconds=2*60*60),
"args": {
"TASK_ARGS": f"clouddq-executable.zip, \
ALL, \
{yaml_path}, \
--gcp_project_id={project_id}, \
--gcp_region_id={bigquery_region}, \
--gcp_bq_dataset_id={data_quality_dataset_id}, \
--target_bigquery_summary_table={FULL_TARGET_TABLE_NAME}"
}
},
"trigger_spec": {
"type_": TRIGGER_SPEC_TYPE
},
"description": "CloudDQ Airflow Task"
}
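# For reference, the task body above is roughly the programmatic equivalent of the
# "gcloud dataplex tasks create" command shown in the check-data-quality guide referenced above
# (flag names per that guide; values in angle brackets are placeholders):
#   gcloud dataplex tasks create <task_id> \
#       --project=<project_id> --location=<dataplex_region> --lake=<lake_name> \
#       --trigger-type=ON_DEMAND \
#       --execution-service-account=<service_account_to_run_dataplex> \
#       --spark-python-script-file=<spark_python_script_file> \
#       --spark-file-uris=<clouddq-executable.zip>,<clouddq-executable.zip.hashsum>,<yaml_path> \
#       --execution-args=^::^TASK_ARGS="<same TASK_ARGS string as above>"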
# Check on the status of the data quality job
# Calls the Dataplex REST API to find the underlying Dataproc batch job, then polls that job until it finishes
def get_clouddq_task_status(task_id):
# Wait for job to start
print ("get_clouddq_task_status STARTED, sleeping for 60 seconds for jobs to start")
time.sleep(60)
# Get auth (default service account running composer worker node)
creds, project = google.auth.default()
    auth_req = google.auth.transport.requests.Request() # required to obtain an access token
creds.refresh(auth_req)
access_token=creds.token
auth_header = {
'Accept': 'application/json',
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + access_token
}
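    # NOTE: The same bearer-token header is reused for both the Dataplex jobs call and the Dataproc Serverless batches call below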
uri = f"https://dataplex.googleapis.com/v1/projects/{project_id}/locations/{dataplex_region}/lakes/{rideshare_dataplex_lake_name}/tasks/{task_id}/jobs"
serviceJob = ""
    # Get the jobs for this Dataplex task
    # Only the first job's status is checked (sufficient for the demo)
try:
response = requests.get(uri, headers=auth_header)
print("get_clouddq_task_status response status code: ", response.status_code)
print("get_clouddq_task_status response status text: ", response.text)
response_json = json.loads(response.text)
if response.status_code == 200:
if ("jobs" in response_json and len(response_json["jobs"]) > 0):
serviceJob = response_json["jobs"][0]["serviceJob"]
print("get_clouddq_task_status serviceJob: ", serviceJob)
else:
errorMessage = "Could not find serviceJob in REST API response"
raise Exception(errorMessage)
else:
errorMessage = "REAT API (serviceJob) response returned response.status_code: " + str(response.status_code)
raise Exception(errorMessage)
except requests.exceptions.RequestException as err:
print(err)
raise err
dataproc_job_id = serviceJob.replace(f"projects/{project_id}/locations/{dataplex_region}/batches/","")
print ("dataproc_job_id: ", dataproc_job_id)
serviceJob_uri = f"https://dataproc.googleapis.com/v1/projects/{project_id}/locations/{dataplex_region}/batches/{dataproc_job_id}"
print ("serviceJob_uri:", serviceJob_uri)
    # Poll the Dataproc batch status for up to 60 iterations (roughly 30 minutes at 30-second intervals)
counter = 1
while (counter < 60):
try:
response = requests.get(serviceJob_uri, headers=auth_header)
print("get_clouddq_task_status response status code: ", response.status_code)
print("get_clouddq_task_status response status text: ", response.text)
response_json = json.loads(response.text)
if response.status_code == 200:
if ("state" in response_json):
task_status = response_json["state"]
print("get_clouddq_task_status task_status: ", task_status)
if (task_status == 'SUCCEEDED'):
return True
if (task_status == 'FAILED'
or task_status == 'CANCELLED'
or task_status == 'ABORTED'):
errorMessage = "Task failed with status of: " + task_status
raise Exception(errorMessage)
# Assuming state is RUNNING or PENDING
time.sleep(30)
else:
errorMessage = "Could not find Job State in REST API response"
raise Exception(errorMessage)
else:
errorMessage = "REAT API response returned response.status_code: " + str(response.status_code)
raise Exception(errorMessage)
except requests.exceptions.RequestException as err:
print(err)
raise err
counter = counter + 1
errorMessage = "The process (get_clouddq_task_status) run for too long. Increase the number of iterations."
raise Exception(errorMessage)
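# The CloudDQ task above writes its summary rows to the BigQuery table referenced by FULL_TARGET_TABLE_NAME;
# the two helpers below read those results (through stored procedures in the curated dataset) and publish
# them to Data Catalog tag templates.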
# Run a SQL query to get the consolidated table-level results
# Attach a tag template in Data Catalog at the table level for rideshare trips
# NOTE: This overwrites the existing tag on each run (it does not add a new one)
def attach_tag_template_to_table():
client = bigquery.Client()
query_job = client.query(f"CALL `{project_id}.{rideshare_dataset_id}.sp_demo_data_quality_table`();")
results = query_job.result() # Waits for job to complete.
for row in results:
# print("{} : {} views".format(row.url, row.view_count))
datacatalog_client = datacatalog_v1.DataCatalogClient()
resource_name = (
f"//bigquery.googleapis.com/projects/{project_id}"
f"/datasets/{rideshare_dataset_id}/tables/bigquery_rideshare_trip"
)
table_entry = datacatalog_client.lookup_entry(
request={"linked_resource": resource_name}
)
# Attach a Tag to the table.
tag = datacatalog_v1.types.Tag()
tag.template = f"projects/{project_id}/locations/{dataplex_region}/tagTemplates/table_dq_tag_template"
tag.name = "table_dq_tag_template"
tag.fields["table_name"] = datacatalog_v1.types.TagField()
tag.fields["table_name"].string_value = "bigquery_rideshare_trip"
tag.fields["record_count"] = datacatalog_v1.types.TagField()
tag.fields["record_count"].double_value = row.record_count
tag.fields["latest_execution_ts"] = datacatalog_v1.types.TagField()
tag.fields["latest_execution_ts"].timestamp_value = row.latest_execution_ts
tag.fields["columns_validated"] = datacatalog_v1.types.TagField()
tag.fields["columns_validated"].double_value = row.columns_validated
tag.fields["columns_count"] = datacatalog_v1.types.TagField()
tag.fields["columns_count"].double_value = row.columns_count
tag.fields["success_pct"] = datacatalog_v1.types.TagField()
tag.fields["success_pct"].double_value = row.success_percentage
tag.fields["failed_pct"] = datacatalog_v1.types.TagField()
tag.fields["failed_pct"].double_value = row.failed_percentage
tag.fields["invocation_id"] = datacatalog_v1.types.TagField()
tag.fields["invocation_id"].string_value = row.invocation_id
        # Get the existing tags (we need to remove any tag already created from this template since duplicates are not allowed)
print ("attach_tag_template_to_table table_entry.name: ", table_entry.name)
page_result = datacatalog_client.list_tags(parent=table_entry.name)
existing_name = ""
# template: "projects/data-analytics-demo-ra5migwp3l/locations/REPLACE-REGION/tagTemplates/table_dq_tag_template"
# Handle the response
for response in page_result:
print("response: ", response)
if (response.template == tag.template):
existing_name = response.name
break
        # This technically removes and re-creates the same tag on every loop iteration
        # Ideally we would use more than one template for different purposes, since a specific template cannot be attached more than once to a table,
        # but different templates can be attached to the same table.
if (existing_name != ""):
print(f"Delete tag: {existing_name}")
datacatalog_client.delete_tag(name=existing_name)
# https://cloud.google.com/python/docs/reference/datacatalog/latest/google.cloud.datacatalog_v1.services.data_catalog.DataCatalogClient#google_cloud_datacatalog_v1_services_data_catalog_DataCatalogClient_create_tag
tag = datacatalog_client.create_tag(parent=table_entry.name, tag=tag)
print(f"Created tag: {tag.name}")
# Run a SQL query to get the column-level results
# Attach a tag template in Data Catalog at the column level for rideshare trips
# NOTE: This overwrites the existing tags on each run (it does not add new ones)
def attach_tag_template_to_columns():
client = bigquery.Client()
    # The query should return each column only once (this code is not meant to handle the same column twice)
    # If the same column appears twice, the second result overwrites the first. Either aggregate the
    # results together or apply a different tag template per result.
query_job = client.query(f"CALL `{project_id}.{rideshare_dataset_id}.sp_demo_data_quality_columns`();")
results = query_job.result() # Waits for job to complete.
for row in results:
datacatalog_client = datacatalog_v1.DataCatalogClient()
resource_name = (
f"//bigquery.googleapis.com/projects/{project_id}"
f"/datasets/{rideshare_dataset_id}/tables/bigquery_rideshare_trip"
)
table_entry = datacatalog_client.lookup_entry(
request={"linked_resource": resource_name}
)
        # Attach a Tag to the column.
tag = datacatalog_v1.types.Tag()
tag.template = f"projects/{project_id}/locations/{dataplex_region}/tagTemplates/column_dq_tag_template"
tag.name = "column_dq_tag_template"
tag.column = row.column_id
tag.fields["table_name"] = datacatalog_v1.types.TagField()
tag.fields["table_name"].string_value = "bigquery_rideshare_trip"
tag.fields["invocation_id"] = datacatalog_v1.types.TagField()
tag.fields["invocation_id"].string_value = row.invocation_id
tag.fields["execution_ts"] = datacatalog_v1.types.TagField()
tag.fields["execution_ts"].timestamp_value = row.execution_ts
tag.fields["column_id"] = datacatalog_v1.types.TagField()
tag.fields["column_id"].string_value = row.column_id
tag.fields["rule_binding_id"] = datacatalog_v1.types.TagField()
tag.fields["rule_binding_id"].string_value = row.rule_binding_id
tag.fields["rule_id"] = datacatalog_v1.types.TagField()
tag.fields["rule_id"].string_value = row.rule_id
tag.fields["dimension"] = datacatalog_v1.types.TagField()
tag.fields["dimension"].string_value = row.dimension
tag.fields["rows_validated"] = datacatalog_v1.types.TagField()
tag.fields["rows_validated"].double_value = row.rows_validated
tag.fields["success_count"] = datacatalog_v1.types.TagField()
tag.fields["success_count"].double_value = row.success_count
tag.fields["success_pct"] = datacatalog_v1.types.TagField()
tag.fields["success_pct"].double_value = row.success_percentage
tag.fields["failed_count"] = datacatalog_v1.types.TagField()
tag.fields["failed_count"].double_value = row.failed_count
tag.fields["failed_pct"] = datacatalog_v1.types.TagField()
tag.fields["failed_pct"].double_value = row.failed_percentage
tag.fields["null_count"] = datacatalog_v1.types.TagField()
tag.fields["null_count"].double_value = row.null_count
tag.fields["null_pct"] = datacatalog_v1.types.TagField()
tag.fields["null_pct"].double_value = row.null_percentage
        # Get the existing tags (we need to remove any tag already created from this template for this column since duplicates are not allowed)
print ("attach_tag_template_to_columns table_entry.name: ", table_entry.name)
page_result = datacatalog_client.list_tags(parent=table_entry.name)
existing_name = ""
# template: "projects/data-analytics-demo-ra5migwp3l/locations/REPLACE-REGION/tagTemplates/column_dq_tag_template"
""" Sample Response
name: "projects/data-analytics-demo-ra5migwp3l/locations/us/entryGroups/@bigquery/entries/cHJvamVjdHMvZGF0YS1hbmFseXRpY3MtZGVtby1yYTVtaWd3cDNsL2RhdGFzZXRzL3RheGlfZGF0YXNldC90YWJsZXMvdGF4aV90cmlwcw/tags/CVg1OS7dOJhY"
template: "projects/data-analytics-demo-ra5migwp3l/locations/REPLACE-REGION/tagTemplates/column_dq_tag_template"
fields {
key: "column_id"
value {
display_name: "Column Name"
string_value: "DOLocationID"
}
}
fields {
key: "dimension"
value {
display_name: "Dimension"
string_value: "INTEGRITY"
}
}
"""
# Handle the response
for response in page_result:
print("response: ", response)
# print("response.fields[column_id]: ", response.fields["column_id"])
if (response.template == tag.template and
"column_id" in response.fields and
response.fields["column_id"].string_value == tag.column):
existing_name = response.name
print(f"existing_name: {existing_name}")
break
        # This technically removes and re-creates the same tag on every loop iteration
        # Ideally we would use more than one template for different purposes, since a specific template cannot be attached more than once to a column,
        # but different templates can be attached to the same column.
        # Oddly, if create_tag is called for a column and the tag already exists, it overwrites it; the same call errors at the table level.
if (existing_name != ""):
print(f"Delete tag: {existing_name}")
datacatalog_client.delete_tag(name=existing_name)
# https://cloud.google.com/python/docs/reference/datacatalog/latest/google.cloud.datacatalog_v1.services.data_catalog.DataCatalogClient#google_cloud_datacatalog_v1_services_data_catalog_DataCatalogClient_create_tag
tag = datacatalog_client.create_tag(parent=table_entry.name, tag=tag)
print(f"Created tag: {tag.name} on {tag.column}")
with airflow.DAG('sample-rideshare-run-data-quality',
default_args=default_args,
start_date=datetime(2021, 1, 1),
# Add the Composer "Data" directory which will hold the SQL/Bash scripts for deployment
template_searchpath=['/home/airflow/gcs/data'],
# Not scheduled, trigger only
schedule_interval=None) as dag:
    # NOTE: The service account of the Composer worker node must have permission to run these commands
    # Create the dataset that holds the Dataplex data quality results
    # NOTE: This must be in the same region as the BigQuery dataset we are running the data quality checks against
create_data_quality_dataset = BigQueryCreateEmptyDatasetOperator(
task_id="create_dataset",
location=bigquery_region,
project_id=project_id,
dataset_id=data_quality_dataset_id,
exists_ok=True
)
# https://airflow.apache.org/docs/apache-airflow-providers-google/stable/_api/airflow/providers/google/cloud/operators/dataplex/index.html#airflow.providers.google.cloud.operators.dataplex.DataplexCreateTaskOperator
create_dataplex_task = DataplexCreateTaskOperator(
project_id=project_id,
region=dataplex_region,
lake_id=rideshare_dataplex_lake_name,
body=data_quality_config,
dataplex_task_id="cloud-dq-{{ ts_nodash.lower() }}",
task_id="create_dataplex_task",
)
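    # The templated task id (cloud-dq-{{ ts_nodash.lower() }}) is unique per DAG run; the same template
    # string is passed to get_clouddq_task_status below so both resolve to the same value at runtime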
get_clouddq_task_status = PythonOperator(
task_id='get_clouddq_task_status',
python_callable= get_clouddq_task_status,
op_kwargs = { "task_id" : "cloud-dq-{{ ts_nodash.lower() }}" },
execution_timeout=timedelta(minutes=300),
dag=dag,
)
attach_tag_template_to_table = PythonOperator(
task_id='attach_tag_template_to_table',
python_callable= attach_tag_template_to_table,
execution_timeout=timedelta(minutes=5),
dag=dag,
)
attach_tag_template_to_columns = PythonOperator(
task_id='attach_tag_template_to_columns',
python_callable= attach_tag_template_to_columns,
execution_timeout=timedelta(minutes=5),
dag=dag,
)
create_data_quality_dataset >> \
create_dataplex_task >> get_clouddq_task_status >> \
attach_tag_template_to_table >> attach_tag_template_to_columns
"""
Sample dataplex output from REST API call
{
"jobs": [
{
"name": "projects/781192597639/locations/REPLACE-REGION/lakes/rideshare-data-lake-r8immkwx8o/tasks/cloud-dq-20221031t182959/jobs/de934178-cde0-489d-b16f-ac6c1e919431",
"uid": "de934178-cde0-489d-b16f-ac6c1e919431",
"startTime": "2022-10-31T18:30:40.959187Z",
"service": "DATAPROC",
"serviceJob": "projects/paternostro-9033-2022102613235/locations/REPLACE-REGION/batches/de934178-cde0-489d-b16f-ac6c1e919431-0"
}
],
"nextPageToken": "Cg5iDAiyqICbBhCQsqf7Ag"
}
"""
"""
Sample dataproc output from REST API call
curl \
'https://dataproc.googleapis.com/v1/projects/paternostro-9033-2022102613235/locations/REPLACE-REGION/batches/de934178-cde0-489d-b16f-ac6c1e919431-0?key=[YOUR_API_KEY]' \
--header 'Authorization: Bearer [YOUR_ACCESS_TOKEN]' \
--header 'Accept: application/json' \
--compressed
{
"name": "projects/paternostro-9033-2022102613235/locations/REPLACE-REGION/batches/de934178-cde0-489d-b16f-ac6c1e919431-0",
"uuid": "af6fd3a5-9ed3-4459-a13b-28d254732704",
"createTime": "2022-10-31T18:30:40.959187Z",
"pysparkBatch": {
"mainPythonFileUri": "gs://dataplex-clouddq-artifacts-REPLACE-REGION/clouddq_pyspark_driver.py",
"args": [
"clouddq-executable.zip",
"ALL",
"gs://processed-paternostro-9033-2022102613235-r8immkwx8o/dataplex/dataplex_data_quality_rideshare.yaml",
"--gcp_project_id=paternostro-9033-2022102613235",
"--gcp_region_id=US",
"--gcp_bq_dataset_id=dataplex_data_quality",
"--target_bigquery_summary_table=paternostro-9033-2022102613235.dataplex_data_quality.data_quality_results"
],
"fileUris": [
"gs://dataplex-clouddq-artifacts-REPLACE-REGION/clouddq-executable.zip",
"gs://dataplex-clouddq-artifacts-REPLACE-REGION/clouddq-executable.zip.hashsum",
"gs://processed-paternostro-9033-2022102613235-r8immkwx8o/dataplex/dataplex_data_quality_rideshare.yaml"
]
},
"runtimeInfo": {
"outputUri": "gs://dataproc-staging-REPLACE-REGION-781192597639-yjb84s0j/google-cloud-dataproc-metainfo/81e34cf5-9c71-41cc-98dd-751e70a0e1e5/jobs/srvls-batch-af6fd3a5-9ed3-4459-a13b-28d254732704/driveroutput",
"approximateUsage": {
"milliDcuSeconds": "3608000",
"shuffleStorageGbSeconds": "360800"
}
},
"state": "SUCCEEDED",
"stateTime": "2022-10-31T18:36:46.829934Z",
"creator": "service-781192597639@gcp-sa-dataplex.iam.gserviceaccount.com",
"labels": {
"goog-dataplex-task": "cloud-dq-20221031t182959",
"goog-dataplex-task-job": "de934178-cde0-489d-b16f-ac6c1e919431",
"goog-dataplex-workload": "task",
"goog-dataplex-project": "paternostro-9033-2022102613235",
"goog-dataplex-location": "REPLACE-REGION",
"goog-dataplex-lake": "rideshare-data-lake-r8immkwx8o"
},
"runtimeConfig": {
"version": "1.0",
"properties": {
"spark:spark.executor.instances": "2",
"spark:spark.driver.cores": "4",
"spark:spark.executor.cores": "4",
"spark:spark.dynamicAllocation.executorAllocationRatio": "0.3",
"spark:spark.app.name": "projects/paternostro-9033-2022102613235/locations/REPLACE-REGION/batches/de934178-cde0-489d-b16f-ac6c1e919431-0"
}
},
"environmentConfig": {
"executionConfig": {
"serviceAccount": "dataproc-service-account@paternostro-9033-2022102613235.iam.gserviceaccount.com",
"subnetworkUri": "dataproc-serverless-subnet"
},
"peripheralsConfig": {
"sparkHistoryServerConfig": {}
}
},
"operation": "projects/paternostro-9033-2022102613235/regions/REPLACE-REGION/operations/d209b731-7cdd-3db2-ba59-10c95ede9e75",
"stateHistory": [
{
"state": "PENDING",
"stateStartTime": "2022-10-31T18:30:40.959187Z"
},
{
"state": "RUNNING",
"stateStartTime": "2022-10-31T18:31:41.245940Z"
}
]
}
"""