# dag_utils/tools.py
# Copyright 2023 The Reg Reporting Blueprint Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Utilities for the Composer DAG that executes the Homeloan Delinquency workflow
import json
import os
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
KubernetesPodOperator)
from kubernetes.client.models import V1VolumeMount, V1Volume, V1CSIVolumeSource
# Airflow environment variables (constant at startup)
#
# Environment variables named AIRFLOW_VAR_<NAME> are available both in
# the process environment (via os.getenv) and as Airflow Variables
# ({{ var.value.<NAME> }} in templates).
PROJECT_ID = os.getenv('AIRFLOW_VAR_PROJECT_ID')
REGION = os.getenv('AIRFLOW_VAR_REGION')
BQ_LOCATION = os.getenv('AIRFLOW_VAR_BQ_LOCATION')
GCS_DOCS_BUCKET = os.getenv('AIRFLOW_VAR_GCS_DOCS_BUCKET')
class ComposerPodOperator(KubernetesPodOperator):
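    """KubernetesPodOperator preconfigured for Cloud Composer.

    If doc_dirs is given, the GCS docs bucket is mounted into the pod
    via the GCS FUSE CSI driver, so files written under those
    directories are persisted to a per-task path in the bucket.
    """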
    def __init__(self,
                 # Directories to map into the DOCS GCS bucket
                 # (default is None, not a mutable [], to avoid Python's
                 # shared mutable default argument pitfall)
                 doc_dirs=None,
                 **kwargs):
        # NOTE: The GCS FUSE sidecar has a limitation: it keeps the
        # pod alive for 30 seconds after the main container terminates.
#
# This delay is removed in the gcs-fuse-csi-driver but may not yet
# be available in Composer and GKE Autopilot:
# https://github.com/GoogleCloudPlatform/gcs-fuse-csi-driver/issues/91#issuecomment-1886185228
if doc_dirs:
# Initialize these values in kwargs
kwargs.setdefault('annotations', {})
kwargs.setdefault('volumes', [])
kwargs.setdefault('volume_mounts', [])
            # Annotation instructing GKE to inject the GCS FUSE sidecar
            # container into the pod
kwargs['annotations'].update({
"gke-gcsfuse/volumes": "true",
})
# Add in the docs bucket volume
kwargs['volumes'].append(V1Volume(
name="docs-bucket",
csi=V1CSIVolumeSource(
driver="gcsfuse.csi.storage.gke.io",
read_only=False,
volume_attributes={
'bucketName': GCS_DOCS_BUCKET,
'mountOptions': ','.join([
'implicit-dirs',
'file-mode=0666',
'dir-mode=0777',
]),
},
)
))
# Path in the docs bucket for the files
sub_path = ('{{ dag_run.dag_id }}/{{ task.task_id }}' +
'/{{ execution_date | ts }}')
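            # e.g. this might render to something like:
            #   homeloan_delinquency/run_dbt/2024-01-01T00:00:00+00:00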
            # Mount the docs bucket into each requested directory
for doc_dir in doc_dirs:
kwargs['volume_mounts'].append(V1VolumeMount(
name="docs-bucket",
mount_path=doc_dir,
sub_path=sub_path + doc_dir,
read_only=False,
))
super().__init__(
# Always pull -- if image is updated, we need to use the latest
image_pull_policy='Always',
# See the following URL for why the config file needs to be set:
# https://cloud.google.com/composer/docs/how-to/using/using-kubernetes-pod-operator#version-5-0-0
config_file="/home/airflow/composer_kube_config",
kubernetes_conn_id="kubernetes_default",
# As per
# https://cloud.google.com/composer/docs/composer-2/use-kubernetes-pod-operator,
# use the composer-user-workloads namespace unless workload
            # identity is set up.
namespace='composer-user-workloads',
# Capture all of the logs
get_logs=True,
log_events_on_failure=True,
is_delete_operator_pod=True,
**kwargs)
class DBTComposerPodOperator(ComposerPodOperator):
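    """ComposerPodOperator specialised for running dbt.

    Sets the standard DBT_* environment variables (serialised dbt vars
    plus Airflow task context for dbt metadata) and, when capture_docs
    is True, captures the dbt target/ and logs/ directories into the
    docs bucket.
    """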
    def __init__(self,
                 env_vars=None,
                 dbt_vars=None,
                 doc_dirs=None,
                 capture_docs=True,
                 **kwargs):

        # Copy mutable arguments so that shared defaults and
        # caller-owned objects are never mutated in place
        env_vars = dict(env_vars) if env_vars else {}
        doc_dirs = list(doc_dirs) if doc_dirs else []
# Set DBT_VARS environment variable if necessary
if dbt_vars:
env_vars['DBT_VARS'] = json.dumps(dbt_vars)
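            # (the container entrypoint is expected to forward this to
            # dbt, e.g. via: dbt run --vars "${DBT_VARS}" -- this is an
            # assumption about the image, not enforced here)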
        # Disable coloured output -- the Airflow log viewer does not
        # render ANSI colour codes
env_vars.setdefault('DBT_USE_COLORS', 'false')
# Disable anonymous usage stats
env_vars.setdefault('DBT_SEND_ANONYMOUS_USAGE_STATS', 'false')
# Enable JSON logging (if desired)
# env_vars.setdefault('DBT_LOG_FORMAT', 'json')
# Add the general DBT environment variables
env_vars.update({
'DBT_ENV_CUSTOM_ENV_PROJECT_ID':
'{{ var.value.PROJECT_ID }}',
'DBT_ENV_CUSTOM_ENV_REGION':
'{{ var.value.REGION }}',
'DBT_ENV_CUSTOM_ENV_BQ_LOCATION':
'{{ var.value.BQ_LOCATION }}',
'DBT_ENV_CUSTOM_ENV_GCS_DOCS_BUCKET':
'{{ var.value.GCS_DOCS_BUCKET }}',
})
# Add generic Airflow environment variables
env_vars.update({
'DBT_ENV_CUSTOM_ENV_AIRFLOW_BASE_URL':
os.getenv('AIRFLOW__WEBSERVER__BASE_URL'),
'DBT_ENV_CUSTOM_ENV_AIRFLOW_CTX_TASK_ID':
'{{ task.task_id }}',
'DBT_ENV_CUSTOM_ENV_AIRFLOW_CTX_DAG_ID':
'{{ dag_run.dag_id }}',
'DBT_ENV_CUSTOM_ENV_AIRFLOW_CTX_EXECUTION_DATE':
'{{ execution_date | ts }}',
})
if capture_docs:
doc_dirs = doc_dirs + [
'/dbt/target',
'/dbt/logs',
]
super().__init__(
env_vars=env_vars,
doc_dirs=doc_dirs,
**kwargs)
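
# Example usage (an illustrative sketch only -- the DAG id, image name and
# dbt command below are assumptions, not defined by this module):
#
#     from airflow import DAG
#     import pendulum
#
#     with DAG(dag_id='homeloan_delinquency',
#              start_date=pendulum.datetime(2024, 1, 1, tz='UTC'),
#              schedule=None) as dag:
#         run_dbt = DBTComposerPodOperator(
#             task_id='run_dbt',
#             name='run-dbt',
#             image='us-docker.pkg.dev/example-project/repo/dbt:latest',
#             cmds=['dbt', 'run'],
#             dbt_vars={'as_of_date': '{{ ds }}'},
#         )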