clouddq-migration/dataplex.py (146 lines of code) (raw):

# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from google.cloud import dataplex_v1 from google.cloud import storage from permission import check_bucket_permission import yaml import zipfile from io import BytesIO def get_yaml_data(source_project, source_region, lake_id, task_id): ''' Method to get the data quality yaml spec ''' # Create a client client = dataplex_v1.DataplexServiceClient() # Initialize request argument(s) request = dataplex_v1.GetTaskRequest( name=f"projects/{source_project}/locations/{source_region}/lakes/{lake_id}/tasks/{task_id}", ) # Make the request response = client.get_task(request=request) file_uri = response.spark.file_uris[-1] bucket_name = file_uri.split('/')[-2] file_name = file_uri.split('/')[-1] if not check_bucket_permission(bucket_name): raise PermissionError(f"Permission is denied on the bucket '{bucket_name}'.") if file_uri.endswith('.zip'): yaml_data = unzip_and_read_yaml(bucket_name, file_name) else: yaml_data = read_yaml_file(bucket_name, file_name) trigger_spec = response.trigger_spec return yaml_data, trigger_spec def unzip_and_read_yaml(bucket_name, zip_file_name) -> dict: ''' Method to unzip and read the dq yaml spec files ''' # Initialize a client for Google Cloud Storage storage_client = storage.Client() # Get the bucket and the blob (zip file) bucket = storage_client.bucket(bucket_name) blob = bucket.blob(zip_file_name) # Download the zip file's contents into memory zip_data = BytesIO(blob.download_as_bytes()) # Variable to store all YAML data yaml_data = {} # Open the zip file from memory with zipfile.ZipFile(zip_data, 'r') as zip_ref: # Iterate over each file in the zip archive for file_name in zip_ref.namelist(): # Only process .yaml or .yml files if file_name.endswith('.yaml') or file_name.endswith('.yml'): # Read the content of the YAML file with zip_ref.open(file_name) as file: file_data = yaml.safe_load(file) for key, value in file_data.items(): if key in yaml_data.keys(): value.update(yaml_data[key]) yaml_data[key] = value else: yaml_data.update(file_data) return yaml_data def read_yaml_file(bucket_name, file_name) -> dict: ''' Method to read the yaml file from the storage bucket ''' # Initialize a client client = storage.Client() # Get the bucket bucket = client.get_bucket(bucket_name) # Get the blob (file) blob = bucket.blob(file_name) # Download the contents of the file content = blob.download_as_string() yaml_data = yaml.safe_load(content) return yaml_data def convert_config_to_payload(config): ''' Method to convert a config into payload ''' # Initialize request argument(s) data_scan = dataplex_v1.DataScan() data_scan.data.resource = config.get('resource') data_scan.data.entity = config.get('entity') if 'description' in config: data_scan.description = config['description'] if 'displayName' in config: data_scan.display_name = config['displayName'] if 'labels' in config: data_scan.labels = config['labels'] if 'samplingPercent' in config['dataQualitySpec']: data_scan.data_quality_spec.sampling_percent = config['dataQualitySpec']['samplingPercent'] else: data_scan.data_quality_spec.sampling_percent = 10 data_scan.data_quality_spec.rules = config['dataQualitySpec']['rules'] if 'rowFilter' in config['dataQualitySpec']: data_scan.data_quality_spec.row_filter = config['dataQualitySpec']['rowFilter'] if 'postScanActions' in config['dataQualitySpec']: data_scan.data_quality_spec.post_scan_actions.bigquery_export.results_table = config['dataQualitySpec']['postScanActions']['bigqueryExport']['resultsTable'] if 'executionSpec' in config and 'trigger' in config['executionSpec'] and 'schedule' in config['executionSpec']['trigger']: data_scan.execution_spec.trigger.schedule.cron = config['executionSpec']['trigger']['schedule']['cron'] else: data_scan.execution_spec.trigger.on_demand = {} if 'executionSpec' in config and 'incrementalField' in config['executionSpec']: data_scan.execution_spec.field = config['executionSpec']['incrementalField'] return data_scan def create_datascan(gcp_project_id, location_id, datascan_id, datascan): ''' Method to create a data scan ''' try: # Create a client client = dataplex_v1.DataScanServiceClient() request = dataplex_v1.CreateDataScanRequest( parent=f"projects/{gcp_project_id}/locations/{location_id}", data_scan=datascan, data_scan_id=datascan_id, ) print(f'Creating Datascan: {datascan_id}') # Make the request operation = client.create_data_scan(request=request) response = operation.result() return response except Exception as error: print(f'Error: Failed to create {datascan_id}. ') print(error) return None def list_lakes(gcp_project_id, region_id) -> list: ''' Method to list lakes ''' try: # Create a client client = dataplex_v1.DataplexServiceClient() # Initialize request argument(s) request = dataplex_v1.ListLakesRequest( parent=f"projects/{gcp_project_id}/locations/{region_id}", ) # Make the request page_result = client.list_lakes(request=request) # Handle the response lakes = [] for response in page_result: lakes.append(response.name.split('/')[-1]) return lakes except Exception as error: print(error) return None def list_tasks(gcp_project_id, region_id, lake_id)-> dict: ''' Method to get the CloudDQ tasks ''' try: # Create a client client = dataplex_v1.DataplexServiceClient() # Initialize request argument(s) request = dataplex_v1.ListTasksRequest( parent=f"projects/{gcp_project_id}/locations/{region_id}/lakes/{lake_id}", ) # Make the request page_result = client.list_tasks(request=request) # Handle the response tasks = {} for response in page_result: task_id = response.name.split('/')[-1] if response.spark.file_uris: file_uri = response.spark.file_uris[-1] trigger_spec = response.trigger_spec tasks[task_id] = { 'file_uri': file_uri, 'trigger_spec': trigger_spec } return tasks except Exception as error: print(error) return None