# pyspark-sdk/integration_test/BatchIngestionTest.py
import subprocess
import sys
import os
import time
import atexit
import unittest
from datetime import datetime
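# Make the Spark installation's Python libraries and the bundled
# sagemaker_feature_store_pyspark zip importable before the imports below.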
SPARK_HOME = "/usr/lib/spark"
sys.path.append(f'{SPARK_HOME}/python')
sys.path.append(f'{SPARK_HOME}/python/build')
sys.path.append(f'{SPARK_HOME}/python/lib/py4j-src.zip')
sys.path.append(f'{SPARK_HOME}/python/pyspark')
sys.path.append(os.path.join(sys.path[0], "sagemaker_feature_store_pyspark.zip"))
import boto3
import feature_store_pyspark
from feature_store_pyspark.FeatureStoreManager import FeatureStoreManager
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit, col
from pyspark.sql.types import Row
# Import the required jars to run the application
jars = ",".join(feature_store_pyspark.classpath_jars())
tc = unittest.TestCase()
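# Attach the connector jars to the Spark session; partition-column type inference is
# disabled so partition values are read back as strings rather than inferred types.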
spark = SparkSession.builder \
    .config("spark.jars", jars) \
    .config("spark.sql.sources.partitionColumnTypeInference.enabled", False) \
    .getOrCreate()
sagemaker_client = boto3.client(service_name="sagemaker")
featurestore_runtime = boto3.client(service_name="sagemaker-featurestore-runtime")
s3_client = boto3.client("s3")
caller_identity = boto3.client("sts").get_caller_identity()
account_id = caller_identity["Account"]
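# Source data: a sampled identity table from the fraud detection dataset in the
# public SageMaker sample-files bucket.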
fraud_detection_bucket_name = "sagemaker-sample-files"
identity_file_key = "datasets/tabular/fraud_detection/synthethic_fraud_detection_SA/sampled_identity.csv"
identity_data_object = s3_client.get_object(
    Bucket=fraud_detection_bucket_name, Key=identity_file_key
)
csv_data = spark.sparkContext.parallelize(identity_data_object["Body"].read().decode("utf-8").split('\r\n'))
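# Suffix the feature group names with a timestamp so each test run creates unique resources.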
timestamp_suffix = time.strftime("%d-%H-%M-%S", time.gmtime())
test_feature_group_name_online_only = 'spark-test-online-only-' + timestamp_suffix
test_feature_group_name_glue_table = 'spark-test-glue-' + timestamp_suffix
test_feature_group_name_iceberg_table = 'spark-test-iceberg-' + timestamp_suffix
def clean_up(feature_group_name):
    sagemaker_client.delete_feature_group(FeatureGroupName=feature_group_name)
    print(f"Deleted feature group: {feature_group_name}")
atexit.register(clean_up, test_feature_group_name_online_only)
atexit.register(clean_up, test_feature_group_name_glue_table)
atexit.register(clean_up, test_feature_group_name_iceberg_table)
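# FeatureStoreManager assumes this pre-existing IAM role for Feature Store writes;
# the role must already exist in the account with Feature Store and S3 permissions.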
feature_store_manager = FeatureStoreManager(f"arn:aws:iam::{account_id}:role/feature-store-role")
# For testing purposes, we only take 20 records from the dataset and persist them to the feature store
current_timestamp = time.time()
current_date = datetime.now()
current_time = current_date.strftime('%Y-%m-%dT%H:%M:%SZ')
identity_df = spark.read.options(header='True', inferSchema='True').csv(csv_data).limit(20).cache()
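# Every record needs an event-time feature; attach the same ISO-8601 timestamp to all rows.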
identity_df = identity_df.withColumn("EventTime", lit(current_time))
feature_definitions = feature_store_manager.load_feature_definitions_from_schema(identity_df)
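# Derive the Feature Store feature definitions directly from the DataFrame schema.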
def wait_for_feature_group_creation_complete(feature_group_name):
    status = sagemaker_client.describe_feature_group(FeatureGroupName=feature_group_name).get("FeatureGroupStatus")
    while status == "Creating":
        print("Waiting for Feature Group Creation")
        time.sleep(5)
        status = sagemaker_client.describe_feature_group(FeatureGroupName=feature_group_name).get("FeatureGroupStatus")
    if status != "Created":
        raise RuntimeError(f"Failed to create feature group {feature_group_name}")
# Create a feature group with only online store enabled
response = sagemaker_client.create_feature_group(
    FeatureGroupName=test_feature_group_name_online_only,
    RecordIdentifierFeatureName='TransactionID',
    EventTimeFeatureName='EventTime',
    FeatureDefinitions=feature_definitions,
    OnlineStoreConfig={
        'EnableOnlineStore': True
    }
)
wait_for_feature_group_creation_complete(test_feature_group_name_online_only)
# Test1: Stream ingest to a feature group with only online store enabled
feature_store_manager.ingest_data(input_data_frame=identity_df, feature_group_arn=response.get("FeatureGroupArn"), target_stores=["OnlineStore"])
def verify_online_record(ingested_row: Row, record: list):
    ingested_row_dict = ingested_row.asDict()
    for key in ingested_row_dict.keys():
        ingested_value = ingested_row_dict.get(key, None)
        filtered_record_list = list(filter(lambda feature_value: feature_value["FeatureName"] == key, record))
        if ingested_value is not None:
            filtered_record = filtered_record_list[0]
            tc.assertEqual(str(ingested_row_dict[key]), filtered_record["ValueAsString"])
        else:
            tc.assertEqual(len(filtered_record_list), 0)
for row in identity_df.collect():
    get_record_response = featurestore_runtime.get_record(
        FeatureGroupName=test_feature_group_name_online_only,
        RecordIdentifierValueAsString=str(row["TransactionID"]),
    )
    record = get_record_response["Record"]
    verify_online_record(row, record)
# Create a feature group with Glue table enabled
response = sagemaker_client.create_feature_group(
    FeatureGroupName=test_feature_group_name_glue_table,
    RecordIdentifierFeatureName='TransactionID',
    EventTimeFeatureName='EventTime',
    FeatureDefinitions=feature_definitions,
    OnlineStoreConfig={
        'EnableOnlineStore': True
    },
    OfflineStoreConfig={
        'S3StorageConfig': {
            'S3Uri': f's3://spark-test-bucket-{account_id}/test-offline-store'
        },
        'TableFormat': 'Glue'
    },
    RoleArn=f"arn:aws:iam::{account_id}:role/feature-store-role"
)
wait_for_feature_group_creation_complete(test_feature_group_name_glue_table)
# Test2: Batch ingest to offline store with glue table enabled
feature_store_manager.ingest_data(input_data_frame=identity_df, feature_group_arn=response.get("FeatureGroupArn"), target_stores=["OfflineStore"])
resolved_output_s3_uri = sagemaker_client.describe_feature_group(
    FeatureGroupName=test_feature_group_name_glue_table
).get("OfflineStoreConfig").get("S3StorageConfig").get("ResolvedOutputS3Uri")
event_time_date = datetime.fromtimestamp(current_timestamp)
partitioned_s3_path = '/'.join([resolved_output_s3_uri,
                                f"year={event_time_date.strftime('%Y')}",
                                f"month={event_time_date.strftime('%m')}",
                                f"day={event_time_date.strftime('%d')}",
                                f"hour={event_time_date.strftime('%H')}"])
offline_store_df = spark.read.format("parquet").load(partitioned_s3_path)
appended_columns = ["api_invocation_time", "write_time", "is_deleted"]
# verify the size of input DF and offline store DF are equal
tc.assertEqual(offline_store_df.count(), identity_df.count())
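# Check the appended metadata columns: records are not soft-deleted and
# write_time matches api_invocation_time.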
def verify_appended_columns(row: Row):
    tc.assertEqual(str(row["is_deleted"]), "False")
    tc.assertEqual(datetime.fromisoformat(str(row["write_time"])),
                   datetime.fromisoformat(str(row["api_invocation_time"])))
# verify the values and appended columns are persisted correctly
for row in identity_df.collect():
    offline_store_filtered_df = offline_store_df.filter(
        col("TransactionID").cast("string") == str(row["TransactionID"])
    )
    tc.assertTrue(offline_store_filtered_df.count() == 1)
    tc.assertEqual(offline_store_filtered_df.drop(*appended_columns).first(), row)
    verify_appended_columns(offline_store_filtered_df.first())
# Create a feature group with Iceberg table enabled
response = sagemaker_client.create_feature_group(
    FeatureGroupName=test_feature_group_name_iceberg_table,
    RecordIdentifierFeatureName='TransactionID',
    EventTimeFeatureName='EventTime',
    FeatureDefinitions=feature_definitions,
    OnlineStoreConfig={
        'EnableOnlineStore': True
    },
    OfflineStoreConfig={
        'S3StorageConfig': {
            'S3Uri': f's3://spark-test-bucket-{account_id}/test-offline-store'
        },
        'TableFormat': 'Iceberg'
    },
    RoleArn=f"arn:aws:iam::{account_id}:role/feature-store-role"
)
wait_for_feature_group_creation_complete(test_feature_group_name_iceberg_table)
# Test3: Batch ingest to offline store with Iceberg table enabled
feature_store_manager.ingest_data(input_data_frame=identity_df, feature_group_arn=response.get("FeatureGroupArn"), target_stores=["OfflineStore"])
resolved_output_s3_uri = sagemaker_client.describe_feature_group(
    FeatureGroupName=test_feature_group_name_iceberg_table
).get("OfflineStoreConfig").get("S3StorageConfig").get("ResolvedOutputS3Uri")
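# The Iceberg-format offline store partitions data by a truncated event time
# (EventTime_trunc=YYYY-MM-DD) rather than year/month/day/hour folders, so locate
# the single data file via an S3 listing under the resolved output prefix.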
object_listing = s3_client.list_objects_v2(
    Bucket=f'spark-test-bucket-{account_id}',
    Prefix=resolved_output_s3_uri.replace(f's3://spark-test-bucket-{account_id}/', '', 1)
)
object_list = list(filter(lambda entry: f"EventTime_trunc={event_time_date.strftime('%Y-%m-%d')}" in entry['Key'], object_listing['Contents']))
tc.assertEqual(len(object_list), 1)
offline_store_df = spark.read.format("parquet").load(f's3://spark-test-bucket-{account_id}/{object_list[0]["Key"]}')
# verify the values and appended columns are persisted correctly
for row in identity_df.collect():
    offline_store_filtered_df = offline_store_df.filter(
        col("TransactionID").cast("string") == str(row["TransactionID"])
    )
    tc.assertTrue(offline_store_filtered_df.count() == 1)
    tc.assertEqual(offline_store_filtered_df.drop(*appended_columns).first(), row)
    verify_appended_columns(offline_store_filtered_df.first())