storagetransfer/conftest.py (201 lines of code) (raw):

# Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Useful utilities for STS samples tests. """ import json import os import uuid from azure.storage.blob import BlobServiceClient, ContainerClient import boto3 from google.cloud import pubsub_v1, secretmanager, storage, storage_transfer from google.cloud.storage_transfer import TransferJob import pytest @pytest.fixture(scope="module") def secret_cache(): """ Cached secrets. Prevents multiple causes to Secret Manager. """ return {"aws": None, "azure": None} @pytest.fixture(scope="module") def project_id(): yield os.environ.get("GOOGLE_CLOUD_PROJECT") def retrieve_from_secret_manager(name: str): """ Retrieves a secret given a name. example ``name`` = ``projects/123/secrets/my-secret/versions/latest`` """ client = secretmanager.SecretManagerServiceClient() # retrieve from secret manager response = client.access_secret_version(request={"name": name}) # parse and cache secret from secret manager return response.payload.data.decode("UTF-8") def aws_parse_and_cache_secret_json(payload: str, secret_cache): """ Decodes a JSON string in AWS AccessKey JSON format. Supports both single-key "AccessKey" and JSON with props. ``payload`` examples: - ``{"AccessKey": {"AccessKeyId": "", "SecretAccessKey": ""}`` - ``{"AccessKeyId": "", "SecretAccessKey": ""}`` """ secret = json.loads(payload) # normalize to props as keys if secret.get("AccessKey"): secret = secret.get("AccessKey") secret_cache["aws"] = { "aws_access_key_id": secret["AccessKeyId"], "aws_secret_access_key": secret["SecretAccessKey"], } return secret_cache["aws"] def aws_key_pair(secret_cache): if secret_cache["aws"]: return secret_cache["aws"] sts_aws_secret = os.environ.get("STS_AWS_SECRET") if sts_aws_secret: return aws_parse_and_cache_secret_json(sts_aws_secret, secret_cache) sts_aws_secret_name = os.environ.get("STS_AWS_SECRET_NAME") if sts_aws_secret_name: res = retrieve_from_secret_manager(sts_aws_secret_name) return aws_parse_and_cache_secret_json(res, secret_cache) aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID") aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY") return { "aws_access_key_id": aws_access_key_id, "aws_secret_access_key": aws_secret_access_key, } def azure_parse_and_cache_secret_json(payload: str, secret_cache): """ Decodes a JSON string in JSON format. Supports both single-key "AccessKey" and JSON with props. ``payload`` examples: - ``{"StorageAccount": "", "ConnectionString": "", "SAS": ""}`` """ secret = json.loads(payload) secret_cache["azure"] = { "storage_account": secret["StorageAccount"], "connection_string": secret["ConnectionString"], "sas_token": secret["SAS"], } return secret_cache["azure"] def azure_credentials(secret_cache): if secret_cache["azure"]: return secret_cache["azure"] sts_azure_secret = os.environ.get("STS_AZURE_SECRET") if sts_azure_secret: return azure_parse_and_cache_secret_json(sts_azure_secret, secret_cache) sts_azure_secret_name = os.environ.get("STS_AZURE_SECRET_NAME") if sts_azure_secret_name: res = retrieve_from_secret_manager(sts_azure_secret_name) return azure_parse_and_cache_secret_json(res, secret_cache) raise Exception( "env variables not found: 'STS_AZURE_SECRET'/'STS_AZURE_SECRET_NAME'" ) @pytest.fixture def aws_access_key_id(secret_cache): yield aws_key_pair(secret_cache)["aws_access_key_id"] @pytest.fixture def aws_secret_access_key(secret_cache): yield aws_key_pair(secret_cache)["aws_secret_access_key"] @pytest.fixture def azure_storage_account(secret_cache): yield azure_credentials(secret_cache)["storage_account"] @pytest.fixture def azure_connection_string(secret_cache): yield azure_credentials(secret_cache)["connection_string"] @pytest.fixture def azure_sas_token(secret_cache): yield azure_credentials(secret_cache)["sas_token"] @pytest.fixture def bucket_name(): yield f"sts-python-samples-test-{uuid.uuid4()}" @pytest.fixture def sts_service_account(project_id): client = storage_transfer.StorageTransferServiceClient() account = client.get_google_service_account({"project_id": project_id}) yield account.account_email @pytest.fixture def job_description_unique(project_id: str): """ Generate a unique job description. Attempts to find and delete a job with this generated description after tests are ran. """ # Create description client = storage_transfer.StorageTransferServiceClient() description = f"Storage Transfer Service Samples Test - {uuid.uuid4().hex}" yield description # Remove job based on description as the job's name isn't predetermined transfer_job_to_delete: TransferJob = None transfer_jobs = client.list_transfer_jobs( {"filter": json.dumps({"projectId": project_id})} ) for transfer_job in transfer_jobs: if transfer_job.description == description: transfer_job_to_delete = transfer_job break if ( transfer_job_to_delete and transfer_job_to_delete.status != TransferJob.Status.DELETED ): client.update_transfer_job( { "job_name": transfer_job_to_delete.name, "project_id": project_id, "transfer_job": {"status": storage_transfer.TransferJob.Status.DELETED}, } ) @pytest.fixture def aws_source_bucket(bucket_name: str, secret_cache): """ Creates an S3 bucket for testing. Empties and auto-deletes after tests are ran. """ s3_client = boto3.client("s3", **aws_key_pair(secret_cache)) s3_resource = boto3.resource("s3", **aws_key_pair(secret_cache)) try: s3_client.create_bucket(Bucket=bucket_name) except Exception: # Skip AWS tests until bucket quota limit is resolved, per team discussion. pytest.skip("Skipping due to AWS bucket restrictions and limitations.") yield bucket_name s3_resource.Bucket(bucket_name).objects.all().delete() s3_client.delete_bucket(Bucket=bucket_name) @pytest.fixture def azure_source_container(bucket_name: str, azure_connection_string: str): """ Creates an Azure container for testing. Empties and auto-deletes after tests are ran. """ service: BlobServiceClient = BlobServiceClient.from_connection_string( conn_str=azure_connection_string ) container_client: ContainerClient = service.get_container_client( container=bucket_name ) container_client.create_container() yield bucket_name container_client.delete_container() @pytest.fixture def gcs_bucket(project_id: str, bucket_name: str): """ Yields and auto-cleans up a CGS bucket for use in STS jobs """ storage_client = storage.Client(project=project_id) bucket = storage_client.create_bucket(bucket_name) yield bucket bucket.delete() @pytest.fixture def source_bucket(gcs_bucket: storage.Bucket, sts_service_account: str): """ Yields and auto-cleans up a CGS bucket preconfigured with necessary STS service account read perms """ # Setup policy for STS member: str = f"serviceAccount:{sts_service_account}" objectViewer = "roles/storage.objectViewer" bucketReader = "roles/storage.legacyBucketReader" # Prepare policy policy = gcs_bucket.get_iam_policy(requested_policy_version=3) policy.bindings.append({"role": objectViewer, "members": {member}}) policy.bindings.append({"role": bucketReader, "members": {member}}) # Set policy gcs_bucket.set_iam_policy(policy) yield gcs_bucket @pytest.fixture def destination_bucket(gcs_bucket: storage.Bucket, sts_service_account: str): """ Yields and auto-cleans up a CGS bucket preconfigured with necessary STS service account write perms """ # Setup policy for STS member: str = f"serviceAccount:{sts_service_account}" bucketWriter = "roles/storage.legacyBucketWriter" # Prepare policy policy = gcs_bucket.get_iam_policy(requested_policy_version=3) policy.bindings.append({"role": bucketWriter, "members": {member}}) # Set policy gcs_bucket.set_iam_policy(policy) yield gcs_bucket @pytest.fixture def intermediate_bucket(gcs_bucket: storage.Bucket, sts_service_account: str): """ Yields and auto-cleans up a GCS bucket preconfigured with necessary STS service account write perms """ # Setup policy for STS member: str = f"serviceAccount:{sts_service_account}" objectViewer = "roles/storage.objectViewer" bucketReader = "roles/storage.legacyBucketReader" bucketWriter = "roles/storage.legacyBucketWriter" # Prepare policy policy = gcs_bucket.get_iam_policy(requested_policy_version=3) policy.bindings.append({"role": objectViewer, "members": {member}}) policy.bindings.append({"role": bucketReader, "members": {member}}) policy.bindings.append({"role": bucketWriter, "members": {member}}) # Set policy gcs_bucket.set_iam_policy(policy) yield gcs_bucket @pytest.fixture def agent_pool_name(): """ Yields a source agent pool name """ # use default agent yield "" @pytest.fixture def posix_root_directory(): """ Yields a POSIX root directory """ # use arbitrary path yield "/my-posix-root/" @pytest.fixture def manifest_file(source_bucket: storage.Bucket): """ Yields a transfer manifest file name """ # use arbitrary path and name yield f"gs://{source_bucket.name}/test-manifest.csv" @pytest.fixture def pubsub_id(project_id: str): """ Yields a pubsub subscription ID. Deletes it afterwards """ publisher = pubsub_v1.PublisherClient() topic_id = f"pubsub-sts-topic-{uuid.uuid4()}" topic_path = publisher.topic_path(project_id, topic_id) publisher.create_topic(request={"name": topic_path}) subscriber = pubsub_v1.SubscriberClient() subscription_id = f"pubsub-sts-subscription-{uuid.uuid4()}" subscription_path = subscriber.subscription_path(project_id, subscription_id) subscription = subscriber.create_subscription( request={"name": subscription_path, "topic": topic_path} ) yield str(subscription.name) subscriber.delete_subscription(request={"subscription": subscription_path}) subscriber.close() publisher.delete_topic(request={"topic": topic_path}) @pytest.fixture def sqs_queue_arn(secret_cache): """ Yields an AWS SQS queue ARN. Deletes it afterwards. """ sqs = boto3.resource("sqs", **aws_key_pair(secret_cache), region_name="us-west-1") queue_name = f"sqs-sts-queue-{uuid.uuid4()}" queue = sqs.create_queue(QueueName=queue_name) yield queue.attributes["QueueArn"] queue.delete()