# migration/athena/athena_workgroup_migration.py (105 lines of code) (raw):

import argparse
import os
import uuid

import boto3
import nbformat as nbf

from migration.utils.datazone_helper import get_project_repo


def _build_notebook_bytes(template_file, query_name, query_string, account_id, region):
    """Render one named query into a ``.sqlnb`` notebook and return it as bytes.

    The notebook is read from ``template_file``, a code cell holding
    ``query_string`` is appended, the title is set to ``query_name`` and the
    ``<uniqueid>``, ``<region>`` and ``<aws-account-id>`` placeholders in the
    template's ``metadata.id`` are substituted.

    Args:
        template_file: Path to the ``template.sqlnb`` notebook template.
        query_name: Name of the Athena named query (used as notebook title).
        query_string: SQL text of the named query.
        account_id: AWS account ID substituted into the notebook ID.
        region: AWS region substituted into the notebook ID.

    Returns:
        The serialized notebook, UTF-8 encoded, ending with a newline.
    """
    # Re-read the template for every query so each notebook starts from a
    # pristine copy (appending the cell mutates the notebook in place).
    nb = nbf.read(template_file, as_version=4)

    code_cell = nbf.v4.new_code_cell(query_string)
    code_cell['metadata'] = {'isLimitOn': True, 'displayMode': 'maximized', 'width': 12}
    nb['cells'].append(code_cell)

    nb['metadata']['title'] = query_name
    nb['metadata']['id'] = (
        nb['metadata']['id']
        .replace('<uniqueid>', str(uuid.uuid4()))
        .replace('<region>', region)
        .replace('<aws-account-id>', account_id)
    )

    # Serialize in memory instead of round-tripping through a file named after
    # the query: query names may contain characters that are invalid in local
    # file names, and a crash would otherwise leave stray files behind.
    serialized = nbf.writes(nb)
    if isinstance(serialized, bytes):
        serialized = serialized.decode('utf-8')
    if not serialized.endswith('\n'):
        # nbformat.write() terminates the file with a newline; keep that.
        serialized += '\n'
    return serialized.encode('utf-8')


def migrate_queries(workgroup_name, domain_id, project_id, account_id, region):
    """Copy every named query of an Athena workgroup into the project repo.

    Each named query becomes one ``.sqlnb`` notebook under
    ``athena_saved_queries/<workgroup_name>/`` and all notebooks are pushed to
    the ``main`` branch of the CodeCommit repository returned by
    ``get_project_repo`` in a single commit. Progress and results are printed
    to stdout.

    Args:
        workgroup_name: Athena workgroup whose named queries are migrated.
        domain_id: Domain ID passed to ``get_project_repo``.
        project_id: Project ID passed to ``get_project_repo``.
        account_id: AWS account ID written into each notebook's ID.
        region: AWS region used for the boto3 clients and the notebook ID.
    """
    athena = boto3.client('athena', region_name=region)
    code_commit = boto3.client('codecommit', region_name=region)
    repo = get_project_repo(domain_id, project_id, region)
    branch = "main"

    # Collect the IDs of all named queries in the workgroup (paginated API).
    all_named_query_ids = []
    paginator = athena.get_paginator('list_named_queries')
    for page in paginator.paginate(WorkGroup=workgroup_name):
        all_named_query_ids.extend(page['NamedQueryIds'])

    # The template lives next to this script; resolve its path only once.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    template_file = os.path.join(script_dir, 'template.sqlnb')

    putFilesList = []
    migration_info = []  # One entry per query, used for the final report
    used_paths = set()

    for query_id in all_named_query_ids:
        named_query = athena.get_named_query(NamedQueryId=query_id)['NamedQuery']
        query_name = named_query['Name']
        query_string = named_query['QueryString']

        file_content = _build_notebook_bytes(
            template_file, query_name, query_string, account_id, region
        )

        # A '/' in the query name would silently create nested directories in
        # the repository, so flatten it in the file name.
        safe_name = query_name.replace('/', '_')
        file_path = f'athena_saved_queries/{workgroup_name}/{safe_name}.sqlnb'
        if file_path in used_paths:
            # Two queries share a name: the same path must not appear twice in
            # one commit, so disambiguate with the unique query ID.
            file_path = f'athena_saved_queries/{workgroup_name}/{safe_name}_{query_id}.sqlnb'
        used_paths.add(file_path)

        putFilesList.append({
            'filePath': file_path,
            'fileContent': file_content
        })
        migration_info.append({
            'name': query_name,
            'query_id': query_id,
            'path': file_path
        })

    migrated_count = 0

    # Perform a single commit with all files
    if putFilesList:
        parent_commit_id = code_commit.get_branch(
            repositoryName=repo, branchName=branch
        ).get("branch").get("commitId")

        commit_response = code_commit.create_commit(
            repositoryName=repo,
            branchName=branch,
            parentCommitId=parent_commit_id,
            putFiles=putFilesList
        )

        if 'commitId' in commit_response:
            migrated_count = len(migration_info)
            print("Migration successful. Commit ID:", commit_response['commitId'])
            print("\nMigrated queries:")
            for info in migration_info:
                print(f"Name: {info['name']}")
                print(f"Query ID: {info['query_id']}")
                print(f"Migrated to: {info['path']}")
                print("---")
        else:
            print("Migration failed. No commit was made.")
    else:
        print("No queries to migrate.")

    # Report what was actually committed, not merely what was discovered.
    print(f"Query migration process completed. Total queries migrated: {migrated_count}")


def bring_your_own_workgroup(workgroup_name, domain_id, project_id, account_id, region):
    """Attach an existing Athena workgroup to a DataZone project.

    Tags the workgroup with ``AmazonDataZoneProject=<project_id>`` and then
    updates the first ATHENA connection returned by DataZone
    ``list_connections`` for the project so that it uses ``workgroup_name``.

    Args:
        workgroup_name: Name of the Athena workgroup to attach.
        domain_id: DataZone domain identifier.
        project_id: DataZone project identifier (also used as the tag value).
        account_id: AWS account ID used to build the workgroup ARN.
        region: AWS region used for the boto3 clients and the workgroup ARN.

    Raises:
        LookupError: If the project has no ATHENA connection to update.
    """
    print(f"Tagging Athena workgroup {workgroup_name} with DataZone project ID...")

    athena = boto3.client('athena', region_name=region)
    athena.tag_resource(
        ResourceARN=f'arn:aws:athena:{region}:{account_id}:workgroup/{workgroup_name}',
        Tags=[{'Key': 'AmazonDataZoneProject', 'Value': project_id}]
    )

    print(f"Tagged Athena workgroup {workgroup_name} with DataZone project ID.")

    print(f"Updating default Athena connection with workgroup {workgroup_name}...")

    datazone = boto3.client('datazone', region_name=region)
    default_athena_connection = datazone.list_connections(
        domainIdentifier=domain_id,
        projectIdentifier=project_id,
        type='ATHENA'
    )

    # NOTE(review): the first listed ATHENA connection is treated as the
    # default one; confirm this holds when a project has several.
    connections = default_athena_connection.get('items') or []
    if not connections:
        raise LookupError(
            f"No ATHENA connection found for project {project_id} in domain {domain_id}; "
            f"cannot set workgroup {workgroup_name}."
        )

    datazone.update_connection(
        domainIdentifier=domain_id,
        identifier=connections[0]['connectionId'],
        props={
            'athenaProperties': {
                'workgroupName': workgroup_name
            }
        }
    )

    print(f"Updated default Athena connection with workgroup {workgroup_name}.")


def main():
    """Parse command-line arguments, migrate the queries, then attach the workgroup."""
    parser = argparse.ArgumentParser(description='Migrate Athena named queries to CodeCommit')
    parser.add_argument('--workgroup-name', type=str, required=True, help='Athena workgroup name')
    parser.add_argument('--domain-id', type=str, required=True, help='ID of the SageMaker Unified Studio Domain')
    parser.add_argument('--project-id', type=str, required=True, help='Project ID in the SageMaker Unified Studio Domain')
    parser.add_argument('--account-id', type=str, required=True, help='AWS account ID')
    parser.add_argument('--region', type=str, required=True, help='AWS region')

    args = parser.parse_args()

    migrate_queries(args.workgroup_name, args.domain_id, args.project_id, args.account_id, args.region)
    bring_your_own_workgroup(args.workgroup_name, args.domain_id, args.project_id, args.account_id, args.region)


if __name__ == "__main__":
    main()