import argparse
import os
import sys
import traceback
from datetime import datetime, timedelta
from typing import Tuple

import pandas as pd

from Yugong.Ownership import Ownership
from optimizer import Query_on_DB_Table
from utility import to_seconds, human_readable_size

# Define command-line arguments
# Commonly used
parser = argparse.ArgumentParser(description="Run DB table query optimization tests.")
parser.add_argument("--test", type=str, choices=["long_term", "samplek", "reorg_unaware", "yugong"],
                    required=False, help="Specify which test to run, e.g., long_term, samplek")
parser.add_argument("--view", action="store_true", help="Print the path before real run")  # store False by default
parser.add_argument("--c", type=int, default=30, help="Portion of compute to cloud")
parser.add_argument("--k", type=float, default=1, help="Sample rate of top cost-sensitive jobs")
parser.add_argument("--num_week", type=int, default=2, help="Number of weeks to run")
parser.add_argument("--Spark", action="store_true", help="Test Spark jobs additional to Presto jobs")  # store False by default
parser.add_argument("--rep_rate", type=float, default=0.004, help="Pre-selecting replication budget rate, [0, 1]")
parser.add_argument("--rep_strategy", type=str, default="job_access_density",
                    choices=["job_access_density", "job_access_frequency", "read_traffic_volume",
                             "read_traffic_density", "inverse_dataset_size"],
                    required=False, help="Specify which replication strategy to use, job_access_density by default")

args = parser.parse_args()

day = 7
storage_gb_week = 0.023 * day / 30
egress_gb = 0.02
p_network_gb = 23.3/(100/8*3600)  # 100Gbps => 100/8*3600 GB/hr = $23.3/hr

network_capacity_gb = 8640 * day * 1024  # 800 Gbps = 100 GB/s = 8.64 PB/day * 7 days
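# Sanity check of the unit conversions above (derived from the quoted prices/rates; the 1024 factor
# appears to be a TB -> GB conversion of the per-day figure):
#   storage_gb_week     = $0.023 per GB-month * (7 / 30) month ≈ $0.0054 per GB over the 7-day window
#   p_network_gb        = $23.3/hr / (100 Gbps / 8 b/B * 3600 s/hr) = $23.3 / 45,000 GB ≈ $0.00052 per GB
#   network_capacity_gb = 8640 * 1024 * day GB, i.e. roughly 8.64 PB/day (800 Gbps = 100 GB/s) over 7 days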

binary = True

def read_yugong_df(start_date: datetime, end_date: datetime) -> Tuple[pd.DataFrame, str]:
    if end_date - start_date != timedelta(days=6):
        raise ValueError("The date range must span exactly 7 days (end_date == start_date + 6 days)")
    workload_print_info = f"{start_date.strftime('%Y%m%d')}-{end_date.strftime('%Y%m%d')} Presto Spark jobs"
    job_data_access_df = pd.read_csv(os.path.join("yugongTraces",
                                                    f"report-uown-volume-table-{start_date.strftime('%Y%m%d')}-{end_date.strftime('%Y%m%d')}.csv"),
                                        dtype={'abstractFingerPrint': str,
                                                'db_name': str,
                                                'table_name': str,
                                                'inputDataSize': float,
                                                'outputDataSize': float,
                                                'cputime': float
                                                })
    return job_data_access_df, workload_print_info
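
# Example (the dates used for the initial yugong run below; assumes the matching trace file
# yugongTraces/report-uown-volume-table-20241022-20241028.csv exists):
#   df, info = read_yugong_df(datetime(2024, 10, 22), datetime(2024, 10, 28))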

def prepare_df(start_date: datetime, end_date: datetime, Presto=True, Spark=True) -> Tuple[pd.DataFrame, str]:
    if end_date - start_date != timedelta(days=6):
        raise ValueError("The date range must span exactly 7 days (end_date == start_date + 6 days)")

    if not Presto and not Spark:
        raise ValueError("At least one of Presto and Spark must be True to have data")

    if end_date <= datetime.strptime("2024-05-09", "%Y-%m-%d"):
        assert Presto and not Spark, "Only Presto is available before 2024-05-09"
        job_data_access_df = pd.read_csv(os.path.join("oldTraces",
                    f"report-abFP-volume-table-{start_date.strftime('%m%d')}-{end_date.strftime('%m%d')}-all.csv"),
                     dtype = {'abstractFingerPrint': str,
                              'db_name': str,
                              'table_name': str,
                              'inputDataSize': float,
                              'cputime': str
                              })
        job_data_access_df['db_name'] = job_data_access_df['db_name'].astype(str)
        job_data_access_df['table_name'] = job_data_access_df['table_name'].astype(str)
        job_data_access_df['cputime'] = job_data_access_df['cputime'].apply(to_seconds)
        job_data_access_df['outputDataSize'] = 0
        workload_print_info = f"{start_date.strftime('%Y%m%d')}-{end_date.strftime('%Y%m%d')} Presto jobs"
    else:
        workload_print_info = f"{start_date.strftime('%Y%m%d')}-{end_date.strftime('%Y%m%d')}"
        if Presto:
            presto_job = pd.read_csv(os.path.join("newTraces", f"report-abFP-volume-table-{start_date.strftime('%Y%m%d')}-{end_date.strftime('%Y%m%d')}-Presto.csv"),
                                     dtype={'abstractFingerPrint': str,
                                            'db_name': str,
                                            'table_name': str,
                                            'inputDataSize': float,
                                            'outputDataSize': float,
                                            'cputime': float
                                            })
            presto_job['db_name'] = presto_job['db_name'].astype(str)
            presto_job['table_name'] = presto_job['table_name'].astype(str)
            workload_print_info += " Presto"
        else:
            presto_job = pd.DataFrame()
        if Spark:
            spark_job = pd.read_csv(os.path.join("newTraces", f"report-abFP-volume-table-{start_date.strftime('%Y%m%d')}-{end_date.strftime('%Y%m%d')}-Spark.csv"),
                                    dtype={'abstractFingerPrint': str,
                                           'db_name': str,
                                           'table_name': str,
                                           'inputDataSize': float,
                                           'outputDataSize': float,
                                           'cputime': float
                                           })
            spark_job['db_name'] = spark_job['db_name'].astype(str)
            spark_job['table_name'] = spark_job['table_name'].astype(str)

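            # The Spark trace lists a job's total cputime on every (job, table) row, so summing the raw
            # column would double-count. Dividing each row by the number of rows sharing its
            # abstractFingerPrint makes the column sum match the per-job total computed below.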
            total_cputime = spark_job.groupby("abstractFingerPrint")["cputime"].first().sum()
            print(f"Total cputime of Spark jobs: {total_cputime}")

            abFP_counts = spark_job['abstractFingerPrint'].value_counts()
            spark_job["cputime"] /= spark_job["abstractFingerPrint"].map(abFP_counts)
            print(f"should == Total cputime of Spark jobs after normalization: {spark_job['cputime'].sum()}")
            #assert spark_job['cputime'].sum() // 1000 == total_cputime // 1000, "Normalization error"

            workload_print_info += " Spark"
        else:
            spark_job = pd.DataFrame()
        job_data_access_df = pd.concat([presto_job, spark_job], ignore_index=True)
        workload_print_info += " jobs"
    return job_data_access_df, workload_print_info
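
# Examples (dates taken from the tests below; traces before 2024-05-09 are Presto-only under oldTraces/,
# later traces live under newTraces/ and cover both engines):
#   df, info = prepare_df(datetime(2023, 9, 8), datetime(2023, 9, 14), Presto=True, Spark=False)
#   df, info = prepare_df(datetime(2024, 10, 22), datetime(2024, 10, 28), Presto=True, Spark=True)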

def test_yugong(compute_on_cloud_pct: int = 30, rep_budget_rate: float = 0.004, num_of_week: int = 2):
    try:
        # Validate input
        assert compute_on_cloud_pct in [30, 50, 70], "compute_on_cloud_pct must be one of [30, 50, 70]"

        # Set up parameters (not expected to change)
        #  - avg_bw_usage_ratio (float): Fraction of network bandwidth dedicated to Moirai on average.
        avg_bw_usage_ratio = 0.2  # empirical value
        sample_rate = 1
        output_dir = f"yugong_results"
        os.makedirs(output_dir, exist_ok=True)

        # Redirect stdout to a file
        original_stdout = sys.stdout
        sys.stdout = open(f"{output_dir}/log_c{compute_on_cloud_pct}.txt", "a")
        print(f"Time: {datetime.now()}", flush=True)

        reserved_bandwidth_gb = avg_bw_usage_ratio * network_capacity_gb

        # compute placement and storage constraints
        compute_cloud_min, compute_cloud_max = compute_on_cloud_pct / 100, compute_on_cloud_pct / 100 + 0.05
        storage_on_prem_min, storage_on_prem_max = 1 - compute_on_cloud_pct / 100 - 0.05, 1 - compute_on_cloud_pct / 100
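        # For example, compute_on_cloud_pct=30 gives compute on cloud in [0.30, 0.35] and storage on-prem in [0.65, 0.70]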

        base_path = f"{output_dir}/test_run_c{compute_on_cloud_pct}_bw{avg_bw_usage_ratio:.2f}_local{100 - compute_on_cloud_pct}"
        last_dir = base_path  # Track last processed directory

        # Initialize graph if not in view mode (i.e., not just printing the path for sanity check)
        view_mode = args.view
        graph = None

        # header: abstractFingerPrint,db_name,table_name,inputDataSize,outputDataSize,cputime
        job_data_access_df, workload_print_info = read_yugong_df(datetime.strptime("2024-10-22", "%Y-%m-%d"),
                                                                 datetime.strptime("2024-10-28", "%Y-%m-%d"))

        job_data_access_df['totalDataSize'] = job_data_access_df['inputDataSize'] + job_data_access_df['outputDataSize']
        workload_df = job_data_access_df.groupby('abstractFingerPrint').agg({'totalDataSize': 'sum'}).reset_index()
        workload_df.sort_values('totalDataSize', ascending=False, inplace=True)
        print(f"** Workload info **")
        for abFP, totalDataSize in zip(workload_df['abstractFingerPrint'], workload_df['totalDataSize']):
            print(f"Project {abFP} has access size {human_readable_size(totalDataSize)}", flush=True)

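        # Build a table -> owning project map from the table-size report. The CSV is expected to carry
        # hive_database_name, hive_table_name, uown_names, and dir_size columns; '\N' (read as NaN)
        # marks tables without ownership, which are skipped.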
        ownership = Ownership()
        table_df = pd.read_csv("report-table-size-20241021.csv",
                               dtype={'hive_database_name': str, 'hive_table_name': str, 'uown_names': str},
                                na_values = ['\\N'])
        table_df['table'] = table_df['hive_database_name'] + '.' + table_df['hive_table_name']
        for table, uown_names in zip(table_df['table'], table_df['uown_names']):
            if pd.isna(uown_names):  # Check for NaN values
                continue
            #print(f"Table {table} has ownership {uown_names}", flush=True)
            ownership.add_table_ownership(table, uown_names)
        table_df['project'] = table_df['table'].apply(ownership.get_table_ownership)
        merged_df = table_df.groupby('project').agg({'table': 'count', 'dir_size': 'sum'}).reset_index()
        merged_df.sort_values('dir_size', ascending=False, inplace=True)
        print(f"** Table ownership info **")
        for project, table_count, dir_size in zip(merged_df['project'], merged_df['table'], merged_df['dir_size']):
            print(f"Project {project} has {table_count} tables with total size {human_readable_size(dir_size)}", flush=True)

        rep_list = pd.read_csv(f"{output_dir}/replicated_tables_{rep_budget_rate:.3f}.csv",
                               dtype={'replicated_tables': str})['replicated_tables'].tolist()
        print(f"# of replicated tables: {len(rep_list)}")
        if not view_mode:
            graph = Query_on_DB_Table(
                job_data_access_df,
                workload_print_info,
                'report-table-size-20241021.csv',
                rep_threshold=rep_budget_rate,  # optimizer will figure out the actual budget based on the data
                k=sample_rate,
                log_dir=output_dir,
                yugong=True, # enable Yugong constraint
                ownership=ownership,
                rep_list=rep_list
            )

        if not os.path.exists(base_path):
            graph.solve_gurobi(
                egress_gb, storage_gb_week, compute_cloud_min, compute_cloud_max, reserved_bandwidth_gb,
                base_path, storage_on_prem_min, storage_on_prem_max, True,
                alpha=1, time_limit=24 * 60 * 60,  # 24 hours
                p_network_gb=p_network_gb * 5,  # TODO: Hard-coded now
            )

        # Verify the placement file
        placement_file = os.path.join(base_path, "dataset_placement.csv")
        assert os.path.exists(placement_file), f"File not found: {placement_file}"
        previous_placement = placement_file
        period_start = datetime.strptime("2024-10-29", "%Y-%m-%d")

        for week_offset in range(num_of_week):
            start_date = period_start + timedelta(weeks=week_offset)
            end_date = start_date + timedelta(days=6)
            label = start_date.strftime("%m%d")

            output_path = f"{output_dir}/test_run_c{compute_on_cloud_pct}_bw{avg_bw_usage_ratio:.2f}_local{100 - compute_on_cloud_pct}_{label}"
            if os.path.exists(output_path):
                previous_placement = os.path.join(output_path, "dataset_placement.csv")
                print(f"Skip {output_path}")
                continue

            print(f"Previous placement: {previous_placement}", flush=True)

            job_data_access_df, workload_print_info = read_yugong_df(start_date, end_date)

            if not view_mode:
                # Restore database table states from previous placement
                graph.restore_unique_db_tables(previous_placement, log_dir=last_dir)
                # Update the workload with the new access trace
                graph.update_workload(job_data_access_df, workload_print_info, log_dir=last_dir)
                # Update the previous placement
                graph.update_previous_placement(previous_placement)

            # Optimization parameters
            alpha = 1  # the degree of penalty for table switch

            print(f"Running optimization for week starting on {label}")
            print("----------------------------------------")
            print(f"Inputs: days=7, egress_gb={egress_gb}, storage_gb_week={storage_gb_week}, "
                    f"compute_cloud_min={compute_cloud_min}, compute_cloud_max={compute_cloud_max}, "
                    f"network_cap_gb={reserved_bandwidth_gb}, "
                    f"storage_on_prem_min={storage_on_prem_min}, storage_on_prem_max={storage_on_prem_max}")
            print(f"penalty degree alpha={alpha}")
            print("----------------------------------------", flush=True)

            # Solve optimization problem for this period
            if not view_mode:
                graph.solve_gurobi(
                    egress_gb, storage_gb_week, compute_cloud_min, compute_cloud_max, reserved_bandwidth_gb,
                    output_path, storage_on_prem_min, storage_on_prem_max, True,
                    alpha=alpha, time_limit=24 * 60 * 60,  # 24 hours
                    p_network_gb=p_network_gb * 5,  # TODO: Hard-coded now
                )

            # Update the previous placement for the next iteration
            previous_placement = os.path.join(output_path, "dataset_placement.csv")
            last_dir = output_path

        # Close the log file and restore stdout
        sys.stdout.close()
        sys.stdout = original_stdout

    except Exception as e:
        print(f"Error in test_yugong with compute_on_cloud_pct={compute_on_cloud_pct}, rep_budget_rate={rep_budget_rate}")
        print("Exception traceback:")
        print(traceback.format_exc())
        raise

def test_sample_k(sample_rate: float, compute_on_cloud_pct: int = 30, test_Spark: bool = True,
                  rep_budget_rate: float = 0.004, rep_strategy: str = "job_access_density",
                    num_weeks: int = 2
                  ):
    """
    Given sample ratio, compute on cloud (%), avg bandwidth usage ratio of 800Gbps, and replication budget of total data

    Parameters:
    - sample_rate (float): Sample rate of top cost-sensitive jobs.
    - compute_on_cloud (int): Percentage of resources allocated (suggested to be in [30, 50, 70]).
    - test_Spark (bool): If True, use Spark traces from 2024-2025 (>100 days) along with Presto traces in the same period.
                         If False, use Presto traces from 2023-2024 (>200 days).
    - rep_budget (float): Replication budget constraint (percentage of total data)
    - rep_strategy (str): Selection strategy in pre-selecting process (default: "job_access_density")

    Functionality:
    1. Validates input parameters.
    2. Sets up output directories.
    3. Initializes optimization parameters.
    4. Iterates through weekly data and solves the optimization problem.
    """

    try:
        # Validate input
        assert compute_on_cloud_pct in [30, 50, 70], "compute_on_cloud_pct must be one of [30, 50, 70]"

        # Set up parameters (not expected to change)
        #  - avg_bw_usage_ratio (float): Fraction of network bandwidth dedicated to Moirai on average.
        avg_bw_usage_ratio = 0.02 # empirical value

        # Set up directories
        output_dir = f"sample_{sample_rate:.3f}"
        os.makedirs(output_dir, exist_ok=True)

        # Redirect stdout to a file
        original_stdout = sys.stdout
        if rep_strategy != "job_access_density":
            sys.stdout = open(f"{output_dir}/log_c{compute_on_cloud_pct}_{rep_strategy}.txt", "a")
        else:
            sys.stdout = open(f"{output_dir}/log_c{compute_on_cloud_pct}.txt", "a")
        print(f"Time: {datetime.now()}", flush=True)

        reserved_bandwidth_gb = avg_bw_usage_ratio * network_capacity_gb

        # compute placement and storage constraints
        # For example, if compute_on_cloud_pct = 30, then compute on cloud is in [0.30, 0.35] and storage on-prem is in [0.65, 0.70]
        compute_cloud_min, compute_cloud_max = compute_on_cloud_pct / 100, compute_on_cloud_pct / 100 + 0.05
        storage_on_prem_min, storage_on_prem_max = 1 - compute_on_cloud_pct / 100 - 0.05, 1 - compute_on_cloud_pct / 100

        # Initialize graph
        if rep_strategy != "job_access_density":
            base_path = f"{output_dir}/test_run_c{compute_on_cloud_pct}_bw{avg_bw_usage_ratio:.2f}_local{100 - compute_on_cloud_pct}_{rep_strategy}"
        else:
            base_path = f"{output_dir}/test_run_c{compute_on_cloud_pct}_bw{avg_bw_usage_ratio:.2f}_local{100 - compute_on_cloud_pct}"
        last_dir = base_path  # Track last processed directory

        # Initialize graph if not in view mode (i.e., not just printing the path for sanity check)
        view_mode = args.view
        graph = None

        if test_Spark:
            job_data_access_df, workload_print_info = prepare_df(datetime.strptime("2024-10-22", "%Y-%m-%d"),
                                                                 datetime.strptime("2024-10-28", "%Y-%m-%d"),
                                                                 Presto=True, Spark=True)
        else:
            job_data_access_df, workload_print_info = prepare_df(datetime.strptime("2023-09-08", "%Y-%m-%d"),
                                                                 datetime.strptime("2023-09-14", "%Y-%m-%d"),
                                                                    Presto=True, Spark=False)

        if not view_mode:
            graph = Query_on_DB_Table(
                job_data_access_df,
                workload_print_info,
                'report-table-size-0907.csv' if not test_Spark else 'report-table-size-20241021.csv',
                rep_threshold=rep_budget_rate, # optimizer will figure out the actual budget based on the data
                rep_strategy=rep_strategy,
                k=sample_rate,
                log_dir=output_dir
            )

        # Run the first optimization if not already completed
        if not os.path.exists(base_path):
            graph.solve_gurobi(
                egress_gb, storage_gb_week, compute_cloud_min, compute_cloud_max, reserved_bandwidth_gb,
                base_path, storage_on_prem_min, storage_on_prem_max, True,
                alpha=1, time_limit=30 * 24 * 60 * 60, # 30 days
                p_network_gb=p_network_gb * 5, # TODO: Hard-coded now
            )

        # Verify the placement file
        placement_file = os.path.join(base_path, "dataset_placement.csv")
        assert os.path.exists(placement_file), f"File not found: {placement_file}"
        previous_placement = placement_file

        # Define dynamic date-based traces processing
        if test_Spark:
            period_start = datetime.strptime("2024-10-29", "%Y-%m-%d")
        else:
            period_start = datetime.strptime("2023-09-15", "%Y-%m-%d")  # Start date
        # num_weeks = args.num_week  # Number weekly iterations

        for week_offset in range(num_weeks):
            start_date = period_start + timedelta(weeks=week_offset)
            end_date = start_date + timedelta(days=6)
            label = start_date.strftime("%m%d")

            if rep_strategy != "job_access_density":
                output_path = f"{output_dir}/test_run_c{compute_on_cloud_pct}_bw{avg_bw_usage_ratio:.2f}_local{100 - compute_on_cloud_pct}_{rep_strategy}_{label}"
            else:
                output_path = f"{output_dir}/test_run_c{compute_on_cloud_pct}_bw{avg_bw_usage_ratio:.2f}_local{100 - compute_on_cloud_pct}_{label}"
            if os.path.exists(output_path):
                previous_placement = os.path.join(output_path, "dataset_placement.csv")
                print(f"Skip {output_path}")
                continue

            print(f"Previous placement: {previous_placement}", flush=True)

            job_data_access_df, workload_print_info = prepare_df(start_date, end_date, Presto=True, Spark=test_Spark)

            if not view_mode:
                # Restore database table states from previous placement
                graph.restore_unique_db_tables(previous_placement, log_dir=last_dir)
                # Update the workload with the new access trace
                graph.update_workload(job_data_access_df, workload_print_info, log_dir=last_dir)
                # Update the previous placement
                graph.update_previous_placement(previous_placement)

            # Optimization parameters
            alpha = 1 # the degree of penalty for table switch

            print(f"Running optimization for week starting on {label}")
            print("----------------------------------------")
            print(f"Inputs: days=7, egress_gb={egress_gb}, storage_gb_week={storage_gb_week}, "
                  f"compute_cloud_min={compute_cloud_min}, compute_cloud_max={compute_cloud_max}, "
                  f"network_cap_gb={reserved_bandwidth_gb}, "
                  f"storage_on_prem_min={storage_on_prem_min}, storage_on_prem_max={storage_on_prem_max}")
            print(f"penalty degree alpha={alpha}")
            print("----------------------------------------", flush=True)

            # Solve optimization problem for this period
            if not view_mode:
                graph.solve_gurobi(
                    egress_gb, storage_gb_week, compute_cloud_min, compute_cloud_max, reserved_bandwidth_gb,
                    output_path, storage_on_prem_min, storage_on_prem_max, True,
                    alpha=alpha, time_limit=24 * 60 * 60,
                    p_network_gb=p_network_gb * 5,  # TODO: Hard-coded now
                )

            # Update the previous placement for the next iteration
            previous_placement = os.path.join(output_path, "dataset_placement.csv")
            last_dir = output_path

        # Close the log file
        sys.stdout.close()
        sys.stdout = original_stdout

    except Exception as e:
        print(f"Error in test_sample_k with sample_rate={sample_rate}, compute_on_cloud_pct={compute_on_cloud_pct}")
        print("Exception traceback:")
        print(traceback.format_exc())
        raise

def test_reorganization_cost_unaware(test_Spark: bool = True, view_mode: bool = False):
    """
    Baseline: reorganization-cost-unaware.
    Run the optimization independently for each compute-on-cloud setting from 10% to 90%, in 10% increments.

    Args:
        test_Spark: If True, use Spark jobs in addition to Presto jobs
                    If False, use only Presto jobs
        view_mode: If True, only print the inputs without building or solving the model
    """

    try:
        # Set up parameters
        avg_bw_usage_ratio = 0.02  # empirical value
        sample_rate = 1
        rep_budget_rate = 0.004  # empirical value
        alpha = 0.25  # assuming 10% change in a month (still aggressive)

        # Set up directories
        output_dir = f"long_term"
        os.makedirs(output_dir, exist_ok=True)

        # Redirect stdout to a file
        original_stdout = sys.stdout
        sys.stdout = open(f"{output_dir}/log_unaware.txt", "a")
        print(f"Time: {datetime.now()}", flush=True)

        reserved_bandwidth_gb = avg_bw_usage_ratio * network_capacity_gb

        if test_Spark:
            job_data_access_df, workload_print_info = prepare_df(datetime.strptime("2024-10-22", "%Y-%m-%d"),
                                                                 datetime.strptime("2024-10-28", "%Y-%m-%d"),
                                                                 Presto=True, Spark=True)
        else:
            job_data_access_df, workload_print_info = prepare_df(datetime.strptime("2023-09-08", "%Y-%m-%d"),
                                                                 datetime.strptime("2023-09-14", "%Y-%m-%d"),
                                                                 Presto=True, Spark=False)

        for compute_on_cloud_pct in range(10, 100, 10):
            compute_cloud_min, compute_cloud_max = compute_on_cloud_pct / 100, compute_on_cloud_pct / 100 + 0.05
            storage_on_prem_min, storage_on_prem_max = 1 - compute_on_cloud_pct / 100 - 0.05, 1 - compute_on_cloud_pct / 100

            # Initialize graph
            base_path = f"{output_dir}/test_run_c{compute_on_cloud_pct}_bw{avg_bw_usage_ratio:.2f}_local{100 - compute_on_cloud_pct}"

            if os.path.exists(base_path):
                print(f"Skip {base_path}")
                continue

            print(f"Running optimization for {compute_on_cloud_pct}%")
            print("----------------------------------------")
            print(f"Inputs: days=7, egress_gb={egress_gb}, storage_gb_week={storage_gb_week}, "
                  f"compute_cloud_min={compute_cloud_min}, compute_cloud_max={compute_cloud_max}, "
                  f"network_cap_gb={reserved_bandwidth_gb}, "
                  f"storage_on_prem_min={storage_on_prem_min}, storage_on_prem_max={storage_on_prem_max}")
            print(f"penalty degree alpha={alpha}")
            print("----------------------------------------", flush=True)


            if not view_mode:
                graph = Query_on_DB_Table(
                    job_data_access_df,
                    workload_print_info,
                    'report-table-size-0907.csv' if not test_Spark else 'report-table-size-20241021.csv',
                    rep_threshold=rep_budget_rate,  # optimizer will figure out the actual budget based on the data
                    k=sample_rate,
                    log_dir=output_dir
                )
                graph.solve_gurobi(
                    egress_gb, storage_gb_week, compute_cloud_min, compute_cloud_max, reserved_bandwidth_gb,
                    base_path, storage_on_prem_min, storage_on_prem_max, True,
                    alpha=alpha, time_limit=24 * 60 * 60,  # 24 hours
                    p_network_gb=p_network_gb * 5,  # TODO: Hard-coded now
                )

        # close the log file
        sys.stdout.close()
        sys.stdout = original_stdout

    except Exception as e:
        print(f"Error in test_long_term_effect")
        print("Exception traceback:")
        print(traceback.format_exc())
        raise


def test_long_term_effect(test_Spark: bool = True, view_mode: bool = False):
    """
    Test movement effects under Spark & Presto jobs
    Move from 10% to 90% compute on cloud in 10% increments

    Args:
        test_Spark: If True, use Spark jobs in addition to Presto jobs
                    If False, use only Presto jobs
        view_mode: If True, only print the inputs without building or solving the model
    """

    try:
        # Set up parameters
        avg_bw_usage_ratio = 0.02  # empirical value
        sample_rate = 1
        rep_budget_rate = 0.004 # empirical value
        alpha = 0.25  # assuming 10% change in a month (still aggressive)

        # Set up directories
        output_dir = f"long_term"
        os.makedirs(output_dir, exist_ok=True)

        # Redirect stdout to a file
        original_stdout = sys.stdout
        sys.stdout = open(f"{output_dir}/log.txt", "a")
        print(f"Time: {datetime.now()}", flush=True)

        reserved_bandwidth_gb = avg_bw_usage_ratio * network_capacity_gb

        if test_Spark:
            job_data_access_df, workload_print_info = prepare_df(datetime.strptime("2024-10-22", "%Y-%m-%d"),
                                                                 datetime.strptime("2024-10-28", "%Y-%m-%d"),
                                                                 Presto=True, Spark=True)
        else:
            job_data_access_df, workload_print_info = prepare_df(datetime.strptime("2023-09-08", "%Y-%m-%d"),
                                                                 datetime.strptime("2023-09-14", "%Y-%m-%d"),
                                                                    Presto=True, Spark=False)
        if not view_mode:
            graph = Query_on_DB_Table(
                job_data_access_df,
                workload_print_info,
                'report-table-size-0907.csv' if not test_Spark else 'report-table-size-20241021.csv',
                rep_threshold=rep_budget_rate,  # optimizer will figure out the actual budget based on the data
                k=sample_rate,
                log_dir=output_dir
            )
        else:
            graph = None

        previous_placement = None
        last_dir = None
        for compute_on_cloud_pct in range(10, 100, 10):
            compute_cloud_min, compute_cloud_max = compute_on_cloud_pct / 100, compute_on_cloud_pct / 100 + 0.05
            storage_on_prem_min, storage_on_prem_max = 1 - compute_on_cloud_pct / 100 - 0.05, 1 - compute_on_cloud_pct / 100

            # Initialize graph
            base_path = f"{output_dir}/test_run_c{compute_on_cloud_pct}_bw{avg_bw_usage_ratio:.2f}_local{100 - compute_on_cloud_pct}"
            if compute_on_cloud_pct != 10:
                base_path += "_incr"

            if os.path.exists(base_path):
                previous_placement = os.path.join(base_path, "dataset_placement.csv")
                last_dir = base_path
                print(f"Skip {base_path}")
                continue

            print(f"Previous placement: {previous_placement}", flush=True)
            print(f"last_dir: {last_dir}", flush=True)

            if previous_placement is not None and not view_mode:
                assert last_dir is not None, "last_dir must be set if previous_placement is set"
                graph.restore_unique_db_tables(previous_placement, log_dir=last_dir)
                graph.update_workload(job_data_access_df, workload_print_info, log_dir=last_dir)
                graph.update_previous_placement(previous_placement)

            print(f"Running optimization to study long-term effect (now at {compute_on_cloud_pct}%)")
            print("----------------------------------------")
            print(f"Inputs: days=7, egress_gb={egress_gb}, storage_gb_week={storage_gb_week}, "
                    f"compute_cloud_min={compute_cloud_min}, compute_cloud_max={compute_cloud_max}, "
                    f"network_cap_gb={reserved_bandwidth_gb}, "
                    f"storage_on_prem_min={storage_on_prem_min}, storage_on_prem_max={storage_on_prem_max}")
            print(f"penalty degree alpha={alpha}")
            print("----------------------------------------", flush=True)

            if not view_mode:
                graph.solve_gurobi(
                    egress_gb, storage_gb_week, compute_cloud_min, compute_cloud_max, reserved_bandwidth_gb,
                    base_path, storage_on_prem_min, storage_on_prem_max, True,
                    alpha=alpha, time_limit=24 * 60 * 60,  # 24 hours
                    p_network_gb=p_network_gb * 5,  # TODO: Hard-coded now
                )

            last_dir = base_path
            previous_placement = os.path.join(base_path, "dataset_placement.csv")

        # close the log file
        sys.stdout.close()
        sys.stdout = original_stdout

    except Exception as e:
        print(f"Error in test_long_term_effect")
        print("Exception traceback:")
        print(traceback.format_exc())
        raise


if __name__ == "__main__":
    if args.test == "samplek":
        test_sample_k(sample_rate=args.k, compute_on_cloud_pct=args.c,
                      test_Spark=args.Spark, rep_budget_rate=args.rep_rate,
                      rep_strategy=args.rep_strategy, num_weeks=args.num_week)
    elif args.test == "yugong":
        test_yugong(compute_on_cloud_pct=args.c, rep_budget_rate=args.rep_rate, num_of_week=args.num_week)
    elif args.test == "long_term":
        test_long_term_effect(test_Spark=args.Spark, view_mode=args.view)
    elif args.test == "reorg_unaware":
        test_reorganization_cost_unaware(test_Spark=args.Spark, view_mode=args.view)
    else:
        raise ValueError(f"Unknown or missing --test value: {args.test!r}")