# storagetransfer/conftest.py
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Shared pytest fixtures and helpers for Storage Transfer Service (STS) samples tests.
"""
import json
import os
from typing import Optional
import uuid
from azure.storage.blob import BlobServiceClient, ContainerClient
import boto3
from google.cloud import pubsub_v1, secretmanager, storage, storage_transfer
from google.cloud.storage_transfer import TransferJob
import pytest
@pytest.fixture(scope="module")
def secret_cache():
"""
    Cached secrets, shared across tests to prevent repeated calls to Secret Manager.
"""
return {"aws": None, "azure": None}
@pytest.fixture(scope="module")
def project_id():
yield os.environ.get("GOOGLE_CLOUD_PROJECT")
def retrieve_from_secret_manager(name: str):
"""
Retrieves a secret given a name.
example ``name`` = ``projects/123/secrets/my-secret/versions/latest``
"""
client = secretmanager.SecretManagerServiceClient()
# retrieve from secret manager
response = client.access_secret_version(request={"name": name})
# parse and cache secret from secret manager
return response.payload.data.decode("UTF-8")
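
# A minimal sketch of calling the helper above (the secret name here is
# hypothetical):
#
#   payload = retrieve_from_secret_manager(
#       "projects/my-project/secrets/sts-aws/versions/latest"
#   )
#   creds = json.loads(payload)
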
def aws_parse_and_cache_secret_json(payload: str, secret_cache):
"""
    Decodes a JSON string in AWS AccessKey JSON format.
    Supports both the nested single-key "AccessKey" form and a flat object
    with the same properties.
    ``payload`` examples:
    - ``{"AccessKey": {"AccessKeyId": "", "SecretAccessKey": ""}}``
    - ``{"AccessKeyId": "", "SecretAccessKey": ""}``
"""
secret = json.loads(payload)
# normalize to props as keys
if secret.get("AccessKey"):
secret = secret.get("AccessKey")
secret_cache["aws"] = {
"aws_access_key_id": secret["AccessKeyId"],
"aws_secret_access_key": secret["SecretAccessKey"],
}
return secret_cache["aws"]
def aws_key_pair(secret_cache):
    """
    Returns AWS credentials as a kwargs dict for ``boto3``. Resolution order:
    the ``STS_AWS_SECRET`` env var (inline JSON), the ``STS_AWS_SECRET_NAME``
    env var (a Secret Manager resource name), then the raw
    ``AWS_ACCESS_KEY_ID``/``AWS_SECRET_ACCESS_KEY`` env vars.
    """
if secret_cache["aws"]:
return secret_cache["aws"]
sts_aws_secret = os.environ.get("STS_AWS_SECRET")
if sts_aws_secret:
return aws_parse_and_cache_secret_json(sts_aws_secret, secret_cache)
sts_aws_secret_name = os.environ.get("STS_AWS_SECRET_NAME")
if sts_aws_secret_name:
res = retrieve_from_secret_manager(sts_aws_secret_name)
return aws_parse_and_cache_secret_json(res, secret_cache)
aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
return {
"aws_access_key_id": aws_access_key_id,
"aws_secret_access_key": aws_secret_access_key,
}
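
# The returned dict is shaped to splat straight into boto3, as the S3
# fixtures below do:
#
#   s3_client = boto3.client("s3", **aws_key_pair(secret_cache))
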
def azure_parse_and_cache_secret_json(payload: str, secret_cache):
"""
    Decodes a JSON string in Azure credentials JSON format.
    ``payload`` example:
    - ``{"StorageAccount": "", "ConnectionString": "", "SAS": ""}``
"""
secret = json.loads(payload)
secret_cache["azure"] = {
"storage_account": secret["StorageAccount"],
"connection_string": secret["ConnectionString"],
"sas_token": secret["SAS"],
}
return secret_cache["azure"]
def azure_credentials(secret_cache):
    """
    Returns Azure credentials as a dict. Resolution order: the
    ``STS_AZURE_SECRET`` env var (inline JSON), then the
    ``STS_AZURE_SECRET_NAME`` env var (a Secret Manager resource name).
    Raises if neither is set.
    """
if secret_cache["azure"]:
return secret_cache["azure"]
sts_azure_secret = os.environ.get("STS_AZURE_SECRET")
if sts_azure_secret:
return azure_parse_and_cache_secret_json(sts_azure_secret, secret_cache)
sts_azure_secret_name = os.environ.get("STS_AZURE_SECRET_NAME")
if sts_azure_secret_name:
res = retrieve_from_secret_manager(sts_azure_secret_name)
return azure_parse_and_cache_secret_json(res, secret_cache)
    raise Exception(
        "Environment variables not found: "
        "'STS_AZURE_SECRET' or 'STS_AZURE_SECRET_NAME'"
    )
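
# Sketch of consuming these credentials (the keys match the fixtures below):
#
#   creds = azure_credentials(secret_cache)
#   service = BlobServiceClient.from_connection_string(
#       conn_str=creds["connection_string"]
#   )
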
@pytest.fixture
def aws_access_key_id(secret_cache):
yield aws_key_pair(secret_cache)["aws_access_key_id"]
@pytest.fixture
def aws_secret_access_key(secret_cache):
yield aws_key_pair(secret_cache)["aws_secret_access_key"]
@pytest.fixture
def azure_storage_account(secret_cache):
yield azure_credentials(secret_cache)["storage_account"]
@pytest.fixture
def azure_connection_string(secret_cache):
yield azure_credentials(secret_cache)["connection_string"]
@pytest.fixture
def azure_sas_token(secret_cache):
yield azure_credentials(secret_cache)["sas_token"]
@pytest.fixture
def bucket_name():
    """
    Yields a unique name usable for GCS/S3 buckets and Azure containers.
    """
    yield f"sts-python-samples-test-{uuid.uuid4()}"
@pytest.fixture
def sts_service_account(project_id):
    """
    Yields the Google-managed STS service account email for the project.
    """
    client = storage_transfer.StorageTransferServiceClient()
account = client.get_google_service_account({"project_id": project_id})
yield account.account_email
@pytest.fixture
def job_description_unique(project_id: str):
"""
    Generate a unique job description. Attempts to find and delete a job with
    this generated description after tests are run.
"""
# Create description
client = storage_transfer.StorageTransferServiceClient()
description = f"Storage Transfer Service Samples Test - {uuid.uuid4().hex}"
yield description
# Remove job based on description as the job's name isn't predetermined
    transfer_job_to_delete: Optional[TransferJob] = None
transfer_jobs = client.list_transfer_jobs(
{"filter": json.dumps({"projectId": project_id})}
)
for transfer_job in transfer_jobs:
if transfer_job.description == description:
transfer_job_to_delete = transfer_job
break
if (
transfer_job_to_delete
and transfer_job_to_delete.status != TransferJob.Status.DELETED
):
client.update_transfer_job(
{
"job_name": transfer_job_to_delete.name,
"project_id": project_id,
"transfer_job": {"status": storage_transfer.TransferJob.Status.DELETED},
}
)
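
# A test typically creates its job with this description; STS assigns the
# job name server-side, which is why the teardown above matches on the
# description. A minimal sketch of such a request (the transfer spec fields
# are illustrative, using the GCS fixtures below):
#
#   client.create_transfer_job(
#       {
#           "transfer_job": {
#               "project_id": project_id,
#               "description": job_description_unique,
#               "status": storage_transfer.TransferJob.Status.ENABLED,
#               "transfer_spec": {
#                   "gcs_data_source": {"bucket_name": source_bucket.name},
#                   "gcs_data_sink": {"bucket_name": destination_bucket.name},
#               },
#           }
#       }
#   )
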
@pytest.fixture
def aws_source_bucket(bucket_name: str, secret_cache):
"""
    Creates an S3 bucket for testing. Empties and auto-deletes it after
    tests are run.
"""
s3_client = boto3.client("s3", **aws_key_pair(secret_cache))
s3_resource = boto3.resource("s3", **aws_key_pair(secret_cache))
try:
s3_client.create_bucket(Bucket=bucket_name)
except Exception:
# Skip AWS tests until bucket quota limit is resolved, per team discussion.
pytest.skip("Skipping due to AWS bucket restrictions and limitations.")
yield bucket_name
s3_resource.Bucket(bucket_name).objects.all().delete()
s3_client.delete_bucket(Bucket=bucket_name)
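
# In a transfer spec this bucket plugs in as the AWS source; a sketch with
# field names from the storage_transfer protos:
#
#   "aws_s3_data_source": {
#       "bucket_name": aws_source_bucket,
#       "aws_access_key": {
#           "access_key_id": aws_access_key_id,
#           "secret_access_key": aws_secret_access_key,
#       },
#   }
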
@pytest.fixture
def azure_source_container(bucket_name: str, azure_connection_string: str):
"""
    Creates an Azure container for testing. Empties and auto-deletes it after
    tests are run.
"""
service: BlobServiceClient = BlobServiceClient.from_connection_string(
conn_str=azure_connection_string
)
container_client: ContainerClient = service.get_container_client(
container=bucket_name
)
container_client.create_container()
yield bucket_name
container_client.delete_container()
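
# The matching STS source spec for this container, sketched with field
# names from the storage_transfer protos:
#
#   "azure_blob_storage_data_source": {
#       "storage_account": azure_storage_account,
#       "container": azure_source_container,
#       "azure_credentials": {"sas_token": azure_sas_token},
#   }
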
@pytest.fixture
def gcs_bucket(project_id: str, bucket_name: str):
"""
    Yields and auto-cleans up a GCS bucket for use in STS jobs.
"""
storage_client = storage.Client(project=project_id)
bucket = storage_client.create_bucket(bucket_name)
yield bucket
bucket.delete()
@pytest.fixture
def source_bucket(gcs_bucket: storage.Bucket, sts_service_account: str):
"""
    Yields and auto-cleans up a GCS bucket preconfigured with the necessary
    STS service account read permissions.
"""
# Setup policy for STS
member: str = f"serviceAccount:{sts_service_account}"
    object_viewer = "roles/storage.objectViewer"
    bucket_reader = "roles/storage.legacyBucketReader"
    # Prepare policy
    policy = gcs_bucket.get_iam_policy(requested_policy_version=3)
    policy.bindings.append({"role": object_viewer, "members": {member}})
    policy.bindings.append({"role": bucket_reader, "members": {member}})
# Set policy
gcs_bucket.set_iam_policy(policy)
yield gcs_bucket
@pytest.fixture
def destination_bucket(gcs_bucket: storage.Bucket, sts_service_account: str):
"""
    Yields and auto-cleans up a GCS bucket preconfigured with the necessary
    STS service account write permissions.
"""
# Setup policy for STS
member: str = f"serviceAccount:{sts_service_account}"
    bucket_writer = "roles/storage.legacyBucketWriter"
    # Prepare policy
    policy = gcs_bucket.get_iam_policy(requested_policy_version=3)
    policy.bindings.append({"role": bucket_writer, "members": {member}})
# Set policy
gcs_bucket.set_iam_policy(policy)
yield gcs_bucket
@pytest.fixture
def intermediate_bucket(gcs_bucket: storage.Bucket, sts_service_account: str):
"""
    Yields and auto-cleans up a GCS bucket preconfigured with the necessary
    STS service account read/write permissions.
"""
# Setup policy for STS
member: str = f"serviceAccount:{sts_service_account}"
    object_viewer = "roles/storage.objectViewer"
    bucket_reader = "roles/storage.legacyBucketReader"
    bucket_writer = "roles/storage.legacyBucketWriter"
    # Prepare policy
    policy = gcs_bucket.get_iam_policy(requested_policy_version=3)
    policy.bindings.append({"role": object_viewer, "members": {member}})
    policy.bindings.append({"role": bucket_reader, "members": {member}})
    policy.bindings.append({"role": bucket_writer, "members": {member}})
# Set policy
gcs_bucket.set_iam_policy(policy)
yield gcs_bucket
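
# A quick sanity check of the grants applied by the fixtures above (sketch):
#
#   policy = gcs_bucket.get_iam_policy(requested_policy_version=3)
#   assert any(
#       binding["role"] == "roles/storage.objectViewer"
#       and member in binding["members"]
#       for binding in policy.bindings
#   )
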
@pytest.fixture
def agent_pool_name():
"""
Yields a source agent pool name
"""
# use default agent
yield ""
@pytest.fixture
def posix_root_directory():
"""
Yields a POSIX root directory
"""
# use arbitrary path
yield "/my-posix-root/"
@pytest.fixture
def manifest_file(source_bucket: storage.Bucket):
"""
Yields a transfer manifest file name
"""
# use arbitrary path and name
yield f"gs://{source_bucket.name}/test-manifest.csv"
@pytest.fixture
def pubsub_id(project_id: str):
"""
    Yields a Pub/Sub subscription resource name. Deletes the subscription
    and its topic afterwards.
"""
publisher = pubsub_v1.PublisherClient()
topic_id = f"pubsub-sts-topic-{uuid.uuid4()}"
topic_path = publisher.topic_path(project_id, topic_id)
publisher.create_topic(request={"name": topic_path})
subscriber = pubsub_v1.SubscriberClient()
subscription_id = f"pubsub-sts-subscription-{uuid.uuid4()}"
subscription_path = subscriber.subscription_path(project_id, subscription_id)
subscription = subscriber.create_subscription(
request={"name": subscription_path, "topic": topic_path}
)
yield str(subscription.name)
subscriber.delete_subscription(request={"subscription": subscription_path})
subscriber.close()
publisher.delete_topic(request={"topic": topic_path})
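
# For event-driven transfers the subscription name goes into the job's
# event stream (sketch, proto field names from storage_transfer):
#
#   "event_stream": {"name": pubsub_id}
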
@pytest.fixture
def sqs_queue_arn(secret_cache):
"""
    Yields an AWS SQS queue ARN. Deletes the queue afterwards.
"""
sqs = boto3.resource("sqs", **aws_key_pair(secret_cache), region_name="us-west-1")
queue_name = f"sqs-sts-queue-{uuid.uuid4()}"
queue = sqs.create_queue(QueueName=queue_name)
yield queue.attributes["QueueArn"]
queue.delete()
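
# For AWS event-driven transfers the SQS queue ARN fills the same event
# stream field (sketch):
#
#   "event_stream": {"name": sqs_queue_arn}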