import subprocess
import sys
import os
import time
import atexit
import unittest
from datetime import datetime

SPARK_HOME = "/usr/lib/spark"
sys.path.append(f'{SPARK_HOME}/python')
sys.path.append(f'{SPARK_HOME}/python/build')
sys.path.append(f'{SPARK_HOME}/python/lib/py4j-src.zip')
sys.path.append(f'{SPARK_HOME}/python/pyspark')
sys.path.append(os.path.join(sys.path[0], "sagemaker_feature_store_pyspark.zip"))

import boto3
import feature_store_pyspark
from feature_store_pyspark.FeatureStoreManager import FeatureStoreManager
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit, col
from pyspark.sql.types import Row

# Import the required jars to run the application
jars = ",".join(feature_store_pyspark.classpath_jars())

tc = unittest.TestCase()

spark = SparkSession.builder \
    .config("spark.jars", jars) \
    .config("spark.sql.sources.partitionColumnTypeInference.enabled", False) \
    .getOrCreate()

sagemaker_client = boto3.client(service_name="sagemaker")
featurestore_runtime = boto3.client(service_name="sagemaker-featurestore-runtime")
s3_client = boto3.client("s3")

caller_identity = boto3.client("sts").get_caller_identity()
account_id = caller_identity["Account"]

# Download the sample identity dataset and parallelize it so Spark can read it as CSV
fraud_detection_bucket_name = "sagemaker-sample-files"
identity_file_key = "datasets/tabular/fraud_detection/synthethic_fraud_detection_SA/sampled_identity.csv"
identity_data_object = s3_client.get_object(
    Bucket=fraud_detection_bucket_name, Key=identity_file_key
)
csv_data = spark.sparkContext.parallelize(
    identity_data_object["Body"].read().decode("utf-8").split('\r\n')
)

timestamp_suffix = time.strftime("%d-%H-%M-%S", time.gmtime())
test_feature_group_name_online_only = 'spark-test-online-only-' + timestamp_suffix
test_feature_group_name_glue_table = 'spark-test-glue-' + timestamp_suffix
test_feature_group_name_iceberg_table = 'spark-test-iceberg-' + timestamp_suffix


# Delete the test feature groups when the script exits
def clean_up(feature_group_name):
    sagemaker_client.delete_feature_group(FeatureGroupName=feature_group_name)
    print(f"Deleted feature group: {feature_group_name}")


atexit.register(clean_up, test_feature_group_name_online_only)
atexit.register(clean_up, test_feature_group_name_glue_table)
atexit.register(clean_up, test_feature_group_name_iceberg_table)

feature_store_manager = FeatureStoreManager(f"arn:aws:iam::{account_id}:role/feature-store-role")

# For testing purposes, we only take 20 records from the dataset and persist them to the feature store
current_timestamp = time.time()
current_date = datetime.now()
current_time = current_date.strftime('%Y-%m-%dT%H:%M:%SZ')

identity_df = spark.read.options(header='True', inferSchema='True').csv(csv_data).limit(20).cache()
identity_df = identity_df.withColumn("EventTime", lit(current_time))
feature_definitions = feature_store_manager.load_feature_definitions_from_schema(identity_df)


def wait_for_feature_group_creation_complete(feature_group_name):
    status = sagemaker_client.describe_feature_group(
        FeatureGroupName=feature_group_name
    ).get("FeatureGroupStatus")
    while status == "Creating":
        print("Waiting for Feature Group Creation")
        time.sleep(5)
        status = sagemaker_client.describe_feature_group(
            FeatureGroupName=feature_group_name
        ).get("FeatureGroupStatus")
    if status != "Created":
        raise RuntimeError(f"Failed to create feature group {feature_group_name}")


# Create a feature group with only the online store enabled
response = sagemaker_client.create_feature_group(
    FeatureGroupName=test_feature_group_name_online_only,
    RecordIdentifierFeatureName='TransactionID',
    EventTimeFeatureName='EventTime',
    FeatureDefinitions=feature_definitions,
    OnlineStoreConfig={'EnableOnlineStore': True}
)
wait_for_feature_group_creation_complete(test_feature_group_name_online_only)

# Test 1: Stream ingest to a feature group with only the online store enabled
feature_store_manager.ingest_data(
    input_data_frame=identity_df,
    feature_group_arn=response.get("FeatureGroupArn"),
    target_stores=["OnlineStore"]
)


# Verify that every feature value of the ingested row matches the record returned by GetRecord
def verify_online_record(ingested_row: Row, record_dict: dict):
    ingested_row_dict = ingested_row.asDict()
    for key in ingested_row_dict.keys():
        ingested_value = ingested_row_dict.get(key, None)
        filtered_record_list = list(
            filter(lambda feature_value: feature_value["FeatureName"] == key, record_dict)
        )
        if ingested_value is not None:
            filtered_record = filtered_record_list[0]
            tc.assertEqual(str(ingested_row_dict[key]), filtered_record["ValueAsString"])
        else:
            tc.assertEqual(len(filtered_record_list), 0)


for row in identity_df.collect():
    get_record_response = featurestore_runtime.get_record(
        FeatureGroupName=test_feature_group_name_online_only,
        RecordIdentifierValueAsString=str(row["TransactionID"]),
    )
    record = get_record_response["Record"]
    verify_online_record(row, record)

# Create a feature group with a Glue table enabled
response = sagemaker_client.create_feature_group(
    FeatureGroupName=test_feature_group_name_glue_table,
    RecordIdentifierFeatureName='TransactionID',
    EventTimeFeatureName='EventTime',
    FeatureDefinitions=feature_definitions,
    OnlineStoreConfig={'EnableOnlineStore': True},
    OfflineStoreConfig={
        'S3StorageConfig': {
            'S3Uri': f's3://spark-test-bucket-{account_id}/test-offline-store'
        },
        'TableFormat': 'Glue'
    },
    RoleArn=f"arn:aws:iam::{account_id}:role/feature-store-role"
)
wait_for_feature_group_creation_complete(test_feature_group_name_glue_table)

# Test 2: Batch ingest to the offline store with a Glue table enabled
feature_store_manager.ingest_data(
    input_data_frame=identity_df,
    feature_group_arn=response.get("FeatureGroupArn"),
    target_stores=["OfflineStore"]
)

resolved_output_s3_uri = sagemaker_client.describe_feature_group(
    FeatureGroupName=test_feature_group_name_glue_table
).get("OfflineStoreConfig").get("S3StorageConfig").get("ResolvedOutputS3Uri")

# The Glue table format partitions offline store data by year/month/day/hour of the event time
event_time_date = datetime.fromtimestamp(current_timestamp)
partitioned_s3_path = '/'.join([
    resolved_output_s3_uri,
    f"year={event_time_date.strftime('%Y')}",
    f"month={event_time_date.strftime('%m')}",
    f"day={event_time_date.strftime('%d')}",
    f"hour={event_time_date.strftime('%H')}"
])
offline_store_df = spark.read.format("parquet").load(partitioned_s3_path)
appended_columns = ["api_invocation_time", "write_time", "is_deleted"]

# Verify the sizes of the input DF and the offline store DF are equal
tc.assertEqual(offline_store_df.count(), identity_df.count())


def verify_appended_columns(row: Row):
    tc.assertEqual(str(row["is_deleted"]), "False")
    tc.assertEqual(
        datetime.fromisoformat(str(row["write_time"])),
        datetime.fromisoformat(str(row["api_invocation_time"]))
    )


# Verify the values and the appended columns are persisted correctly
for row in identity_df.collect():
    offline_store_filtered_df = offline_store_df.filter(
        col("TransactionID").cast("string") == str(row["TransactionID"])
    )
    tc.assertTrue(offline_store_filtered_df.count() == 1)
    tc.assertEqual(offline_store_filtered_df.drop(*appended_columns).first(), row)
    verify_appended_columns(offline_store_filtered_df.first())

# Create a feature group with an Iceberg table enabled
response = sagemaker_client.create_feature_group(
    FeatureGroupName=test_feature_group_name_iceberg_table,
    RecordIdentifierFeatureName='TransactionID',
    EventTimeFeatureName='EventTime',
    FeatureDefinitions=feature_definitions,
    OnlineStoreConfig={'EnableOnlineStore': True},
    OfflineStoreConfig={
        'S3StorageConfig': {
            'S3Uri': f's3://spark-test-bucket-{account_id}/test-offline-store'
        },
        'TableFormat': 'Iceberg'
    },
    RoleArn=f"arn:aws:iam::{account_id}:role/feature-store-role"
)
wait_for_feature_group_creation_complete(test_feature_group_name_iceberg_table)

# Test 3: Batch ingest to the offline store with an Iceberg table enabled
feature_store_manager.ingest_data(
    input_data_frame=identity_df,
    feature_group_arn=response.get("FeatureGroupArn"),
    target_stores=["OfflineStore"]
)

resolved_output_s3_uri = sagemaker_client.describe_feature_group(
    FeatureGroupName=test_feature_group_name_iceberg_table
).get("OfflineStoreConfig").get("S3StorageConfig").get("ResolvedOutputS3Uri")

# Locate the single Iceberg data file under the EventTime_trunc partition for the ingestion date
object_listing = s3_client.list_objects_v2(
    Bucket=f'spark-test-bucket-{account_id}',
    Prefix=resolved_output_s3_uri.replace(f's3://spark-test-bucket-{account_id}/', '', 1)
)
object_list = list(filter(
    lambda entry: f"EventTime_trunc={event_time_date.strftime('%Y-%m-%d')}" in entry['Key'],
    object_listing['Contents']
))
tc.assertEqual(len(object_list), 1)
offline_store_df = spark.read.format("parquet").load(
    f's3://spark-test-bucket-{account_id}/{object_list[0]["Key"]}'
)

# Verify the values and the appended columns are persisted correctly
for row in identity_df.collect():
    offline_store_filtered_df = offline_store_df.filter(
        col("TransactionID").cast("string") == str(row["TransactionID"])
    )
    tc.assertTrue(offline_store_filtered_df.count() == 1)
    tc.assertEqual(offline_store_filtered_df.drop(*appended_columns).first(), row)
    verify_appended_columns(offline_store_filtered_df.first())