clouddq-migration/lib.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import random
import string
import yaml
from yaml.loader import SafeLoader

class LineNumberLoader(SafeLoader):
    def construct_mapping(self, node, deep=False):
        """
        Construct the mapping as SafeLoader does, then record the 0-based
        starting line of the mapping node under the '__line__' key.
        """
mapping = super().construct_mapping(node, deep)
mapping['__line__'] = node.start_mark.line
return mapping
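
# Illustrative use of LineNumberLoader (not part of the original module): loading
# a document such as "taskId: my-project.us-central1.my-lake.my-task" with this
# loader yields the mapping plus the 0-based line where the block starts:
#   yaml.load("taskId: my-project.us-central1.my-lake.my-task", Loader=LineNumberLoader)
#   -> {'taskId': 'my-project.us-central1.my-lake.my-task', '__line__': 0}
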
def validateConfigFields(config) -> None:
    """
    Recursively validate that no config field is None.
    return: None
    """
if isinstance(config, dict):
for key, value in config.items():
if value is None:
raise ValueError(f"Field '{key}' is None for the block at line {config.get('__line__')}")
validateConfigFields(value)
elif isinstance(config, list):
for item in config:
validateConfigFields(item)
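
# Illustrative use of validateConfigFields (not part of the original module): a
# block with an empty (None) field fails validation, and the error points at the
# line recorded by LineNumberLoader:
#   validateConfigFields({'taskId': None, '__line__': 3})
#   -> ValueError: Field 'taskId' is None for the block at line 3
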
def removeLineKeys(config):
    """
    Recursively remove the '__line__' keys added by LineNumberLoader.
    return: config without '__line__' keys
    """
if isinstance(config, list):
return [removeLineKeys(item) for item in config]
elif isinstance(config, dict):
return {
key: removeLineKeys(value)
for key, value in config.items()
if key != '__line__'
}
else:
return config
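
# Illustrative use of removeLineKeys (not part of the original module): the helper
# strips the bookkeeping key at every nesting level and leaves real fields intact:
#   removeLineKeys({'taskId': 'a.b.c.d', '__line__': 0, 'spec': {'x': 1, '__line__': 2}})
#   -> {'taskId': 'a.b.c.d', 'spec': {'x': 1}}
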
def validateConfigFile(config_path) -> list:
    """
    Method to validate the Config File
    return: configs
    """
# load the config file
with open(config_path, 'r') as f:
config_file = list(yaml.load_all(f, Loader=LineNumberLoader))
# validate the config file
for config in config_file:
if not {'taskId'} <= config.keys():
            raise ValueError(
                "Config file must define the required config field "
                f"'taskId' at line {config.get('__line__')}"
            )
# validate format for taskId
task = config['taskId']
validate_task(task)
# validate nested fields
validateConfigFields(config)
configs = [removeLineKeys(config) for config in config_file]
return configs
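
# Illustrative use of validateConfigFile (not part of the original module),
# assuming a hypothetical file 'configs.yaml' with a single document:
#   taskId: my-project.us-central1.my-lake.my-task
# then validateConfigFile('configs.yaml') returns:
#   [{'taskId': 'my-project.us-central1.my-lake.my-task'}]
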
def merge_configs(config, new_config) -> dict:
    """
    Method to recursively merge configs; values from new_config take precedence.
    return: final_config
    """
# Copy original config to avoid modifying it directly
final_config = config.copy()
# Update the final config with the new config
for key, value in new_config.items():
if isinstance(value, dict) and key in final_config:
# Recursively update dictionary if the key exists and both are dictionaries
final_config[key] = merge_configs(final_config.get(key, {}), value)
else:
# Otherwise, update or add the key
final_config[key] = value
return final_config
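
# Illustrative use of merge_configs (not part of the original module): nested
# dictionaries are merged key by key, and values from new_config win on conflicts:
#   merge_configs({'a': {'x': 1, 'y': 2}, 'b': 1}, {'a': {'y': 3}, 'c': 4})
#   -> {'a': {'x': 1, 'y': 3}, 'b': 1, 'c': 4}
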
def validate_task(task) -> None:
    """
    Method to validate the format of task IDs
    """
if not re.match(r'^[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+$', task):
raise ValueError(
f"Error: {task} does not match the expected format 'project_id.location_id.lake_id.task_id'")
def generate_config(yaml_data) -> dict:
    '''
    Method to generate the auto data quality scan config for the task
    '''
config = {
'dataQualitySpec': {
'rules': []
}
}
rule_bindings = yaml_data.get('rule_bindings')
entities = yaml_data.get('entities')
rules = yaml_data.get('rules')
row_filters = yaml_data.get('row_filters')
if not rule_bindings:
        print('No rule binding ID present to migrate. Skipped.')
else:
for rule_binding in rule_bindings:
column_id = rule_bindings[rule_binding]['column_id']
if 'entity_id' in rule_bindings[rule_binding]:
entity_id = rule_bindings[rule_binding]['entity_id']
project_id = entities[entity_id]['project_name']
dataset_id = entities[entity_id]['dataset_name']
table_id = entities[entity_id]['table_name']
if column_id in entities[entity_id]['columns'].keys():
column = entities[entity_id]['columns'][column_id]['name']
else:
column_id = column_id.upper()
column = entities[entity_id]['columns'][column_id]['name']
config['resource'] = f'//bigquery.googleapis.com/projects/{project_id}/datasets/{dataset_id}/tables/{table_id}'
else:
entity_uri = rule_bindings[rule_binding]['entity_uri']
if entity_uri.startswith('dataplex'):
                    print('Dataplex entity source type is not supported. Skipped.')
continue
else:
entity = entity_uri.split('/')
project_id = entity[3]
dataset_id = entity[5]
table_id = entity[7]
column = column_id
config['resource'] = f'//bigquery.googleapis.com/projects/{project_id}/datasets/{dataset_id}/tables/{table_id}'
rule_ids = rule_bindings[rule_binding]['rule_ids']
for rule_id in rule_ids:
if isinstance(rule_id, str):
if rules[rule_id]['rule_type'] == 'NOT_NULL':
config['dataQualitySpec']['rules'].append({
'dimension': 'COMPLETENESS',
'column': column,
'non_null_expectation': {},
})
elif rules[rule_id]['rule_type'] == 'REGEX':
config['dataQualitySpec']['rules'].append({
'dimension': 'ACCURACY',
'column': column,
'regex_expectation': {
"regex": rules[rule_id]['params']['pattern']
},
})
else:
                        custom_sql_expr = 'NOT($column = "") '
                        custom_sql_expr = custom_sql_expr.replace("$column", column)
config['dataQualitySpec']['rules'].append({
'dimension': 'COMPLETENESS',
'column': column,
'row_condition_expectation': {
"sql_expression": custom_sql_expr
},
})
else:
for key, value in rule_id.items():
rule = key
params = value
if rules[rule]['rule_type'] == 'CUSTOM_SQL_EXPR':
custom_sql_expr = rules[rule]['params']['custom_sql_expr']
for key, value in params.items():
custom_sql_expr = custom_sql_expr.replace(f"${key}", str(value))
custom_sql_expr = custom_sql_expr.replace(f"$column", column)
config['dataQualitySpec']['rules'].append({
'dimension': 'ACCURACY',
'column': column,
'row_condition_expectation': {
"sql_expression": custom_sql_expr
},
})
else:
custom_sql_statement = rules[rule]['params']['custom_sql_statement']
for key, value in params.items():
custom_sql_statement = custom_sql_statement.replace(f"${key}", str(value))
custom_sql_statement = custom_sql_statement.replace(f"$column", column)
custom_sql_statement = custom_sql_statement.replace(f"data", f'{project_id}.{dataset_id}.{table_id}')
config['dataQualitySpec']['rules'].append({
'dimension': 'ACCURACY',
'column': column,
'sql_assertion': {
"sql_statement": custom_sql_statement
},
})
return config
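
# Illustrative use of generate_config (not part of the original module): a minimal
# CloudDQ-style input sketched from the fields this function reads; all names below
# are hypothetical.
#   yaml_data = {
#       'entities': {
#           'CUSTOMERS': {
#               'project_name': 'my-project',
#               'dataset_name': 'my_dataset',
#               'table_name': 'customers',
#               'columns': {'EMAIL': {'name': 'email'}},
#           }
#       },
#       'rules': {'NOT_NULL_SIMPLE': {'rule_type': 'NOT_NULL'}},
#       'row_filters': {},
#       'rule_bindings': {
#           'EMAIL_NOT_NULL': {
#               'entity_id': 'CUSTOMERS',
#               'column_id': 'EMAIL',
#               'rule_ids': ['NOT_NULL_SIMPLE'],
#           }
#       },
#   }
#   generate_config(yaml_data) sets 'resource' to
#   '//bigquery.googleapis.com/projects/my-project/datasets/my_dataset/tables/customers'
#   and appends one COMPLETENESS rule with a non_null_expectation on column 'email'.
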
def generate_id() -> str:
    '''
    Method to generate a random auto data quality datascan ID
    '''
    # generate datascan id
    letters_and_digits = string.ascii_lowercase + string.digits
    random_string = ''.join(random.choice(letters_and_digits) for _ in range(28))
    datascan_id = f'auto-dq-{random_string}'
    return datascan_id
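
# Illustrative use of generate_id (not part of the original module): the result is
# the 'auto-dq-' prefix followed by 28 random lowercase letters and digits, e.g.:
#   generate_id()  ->  'auto-dq-k3f9x0q7b2mzp8r1c6v4t5n0wa7d'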