import os
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.ticker import FuncFormatter
from pandas import date_range

from utility import human_readable_size, parse_size

# Global default font size for all Matplotlib text.
plt.rcParams.update({'font.size': 20})

# Embed TrueType (Type 42) fonts in PDF/PS output so text remains editable
# in vector-graphics tools (commonly required for publication figures).
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42

# Default colors for Matplotlib (standard tab10 property cycle).
colors_default = plt.rcParams['axes.prop_cycle'].by_key()['color']
replication_cost_color = colors_default[0]  # Blue
egress_cost_color = colors_default[1]       # Orange
network_cost_color = colors_default[2]       # Green
egress_traffic_color = colors_default[5]     # Brown (cycle index 5; index 3 is red)
ingress_traffic_color = colors_default[4]     # Purple

# Hatching patterns applied to stacked-bar segments so they remain
# distinguishable in grayscale prints.
hatch_patterns = ['/', '\\', '|', '-', '+', 'x', 'o', 'O', '.', '*']

# Axis scale mode.  NOTE(review): not referenced anywhere in this chunk —
# confirm whether it is still used elsewhere.
scale_mode = "linear"

# Base font size used for axis labels/ticks throughout this module.
font_size = 20


def overall_stats(df: pd.DataFrame, tag: str):
    """Compute per-cloud-target cost and traffic statistics for one system.

    Filters the log to the 'size-predict' routing mode, derives traffic
    volumes (job ingress/egress plus data-movement traffic) and weekly cost
    components (egress, replication, network), then aggregates one result
    row per ``cloud_computation_target`` value.

    Args:
        df (pd.DataFrame): Raw log with per-engine ingress/egress byte
            columns, movement bytes, replication bytes and P95 traffic rate.
        tag (str): Name identifier (e.g., "Moirai") for the results.

    Returns:
        pd.DataFrame: One row per cloud computation target with the
        computed cost and volume metrics.
    """
    data = df[df['mode'] == 'size-predict'].copy()

    # Job-level traffic across both query engines.
    data['job_ingress_bytes'] = data['ingress_byte_Presto'] + data['ingress_byte_Spark']
    data['job_egress_bytes'] = data['egress_byte_Presto'] + data['egress_byte_Spark']

    # Total traffic including replica/data-movement transfers.
    data['ingress_volume'] = data['job_ingress_bytes'] + data['movement_ingress_bytes']
    data['egress_volume'] = data['job_egress_bytes'] + data['movement_egress_bytes']
    data['traffic_volume'] = data['ingress_volume'] + data['egress_volume']

    # Per-engine volumes used by the cost model.  NOTE(review): movement
    # bytes are added to BOTH engines, so they contribute twice to 'cost'
    # below — confirm this is intended.
    data['ingress_volume_Spark'] = data['ingress_byte_Spark'] + data['movement_ingress_bytes']
    data['ingress_volume_Presto'] = data['ingress_byte_Presto'] + data['movement_ingress_bytes']
    data['egress_volume_Spark'] = data['egress_byte_Spark'] + data['movement_egress_bytes']
    data['egress_volume_Presto'] = data['egress_byte_Presto'] + data['movement_egress_bytes']
    # Weekly $ cost: $0.02/GiB egress per engine plus $0.023/GiB-month
    # replica storage prorated to one week (/4).
    data['cost'] = (data['egress_volume_Presto'] / 1024 ** 3 * 0.02
                    + data['egress_volume_Spark'] / 1024 ** 3 * 0.02
                    + data['rep_bytes'] / 1024 ** 3 * 0.023 / 4)

    rows = []
    for target in data['cloud_computation_target'].unique():
        subset = data[data['cloud_computation_target'] == target]

        # Network cost is driven by the peak P95 rate across periods,
        # priced per 100-Gbps link ($23.3/hour * 24 h * 7 d).
        network_cost = subset['P95_traffic_bps'].max() / (100 * 1024 ** 3) * 23.3 * 24 * 7
        egress_cost_presto = subset['egress_volume_Presto'].mean() / 1024 ** 3 * 0.02
        egress_cost_spark = subset['egress_volume_Spark'].mean() / 1024 ** 3 * 0.02
        rep_cost = subset['rep_bytes'].mean() / 1024 ** 3 * 0.023 / 4
        total_cost = subset['cost'].mean() + network_cost

        rows.append({
            "tag": tag,
            "c": target,
            "network_cost": network_cost,
            "egress_cost_Spark": egress_cost_spark,
            "egress_cost_Presto": egress_cost_presto,
            "egress_cost": egress_cost_spark + egress_cost_presto,
            "rep_cost": rep_cost,
            "total_cost": total_cost,
            "total_cost_std": subset['cost'].std(),  # week-to-week spread
            "ingress_volume": subset['ingress_volume'].mean() / 1024 ** 4,  # TiB
            "egress_volume": subset['egress_volume'].mean() / 1024 ** 4,  # TiB
        })

    return pd.DataFrame(rows)

def draw_overall_new(front: bool = False,
                     job: bool = False,
                     pr: bool = False):
    """Draw per-cloud-target stacked bars of weekly cost (and traffic).

    At most one of ``front``/``job``/``pr`` may be True and selects which
    group of systems is compared; with all three False the main end-to-end
    comparison is drawn.  Aggregated stats are written to
    'overall_stats_new.csv' and one figure
    'overall_comparison_c_<c><suffix>.pdf' is saved per
    cloud_computation_target value found in the data.

    Args:
        front (bool): Compare against prior-production baselines
            (cost panel only, no traffic panel).
        job (bool): Compare job-distribution variants.
        pr (bool): Compare replication-rate / sample-rate sweeps.
    """
    # Fix: the original check (`not front or not pr or not job`) only failed
    # when ALL THREE flags were set; enforce the documented "at most one".
    assert sum([front, job, pr]) <= 1, "Only one of front, job, and pr can be True"
    todo_dfs = []

    # header: tag, c, network_cost, egress_cost_Spark, egress_cost_Presto, rep_cost,
    # ingress_volume_Spark, ingress_volume_Presto, egress_volume_Spark, egress_volume_Presto
    baseline_df = pd.read_csv(f'../baselines_done/log.csv')
    if front:
        suffix = "_front"
        todo_dfs.append(overall_stats(pd.read_csv(f'../sample_1.000_rep0.002/log.csv'), "Moirai\n(Our)"))
        todo_dfs.append(overall_stats(pd.read_csv(f'../yugong_results_rep0.000/log.csv'), "Yugong\n(Alibaba)"))
        todo_dfs.append(overall_stats(baseline_df[baseline_df['tag'] == "No\nRep"], "No Rep\n(Spotify)"))
        todo_dfs.append(overall_stats(baseline_df[baseline_df['tag'] == "3M\n21%"], "Rep 3Mon.\n(Twitter)"))
    elif job:
        suffix = "_job"
        todo_dfs.append(overall_stats(baseline_df[baseline_df['tag'] == "Volley\n2.5%"], "Volley\nRepTop2.5%"))
        todo_dfs.append(overall_stats(baseline_df[baseline_df['tag'] == "MoiJob\n0.2%"], "Moi-\nJobDist"))
        todo_dfs.append(overall_stats(pd.read_csv(f'../sample_1.000_rep0.002/log.csv'), "Moirai"))
    elif pr:
        suffix = "_pr"
        todo_dfs.append(overall_stats(baseline_df[baseline_df['tag'] == "MoiJob\n0.2%"], "Moi\nJobDist"))
        for rate in [0.001, 0.002, 0.004]:
            df = pd.read_csv(f'../sample_1.000_rep{rate:.3f}/log.csv')
            todo_dfs.append(overall_stats(df, f"Moi\nPR{rate * 100:.1f}%"))
        for sample_rate in [0.010, 0.050]:  # 0.001,
            df = pd.read_csv(f'../sample_{sample_rate:.3f}/log.csv')
            todo_dfs.append(overall_stats(df, f"Moi\n{sample_rate * 100:.0f}%Job"))
    else:
        suffix = ""
        todo_dfs.append(overall_stats(pd.read_csv(f'../sample_1.000_rep0.002/log.csv'), "Moirai"))
        todo_dfs.append(overall_stats(pd.read_csv(f'../yugong_results_rep0.000/log.csv'), "Yugong"))
        todo_dfs.append(overall_stats(baseline_df[baseline_df['tag'] == "Volley\n0%"], "Volley"))
        todo_dfs.append(overall_stats(baseline_df[baseline_df['tag'] == "RTD\n2.5%"], "Rep\nTop2.5%"))
        todo_dfs.append(overall_stats(baseline_df[baseline_df['tag'] == "No\nRep"], "No\nRep"))
        todo_dfs.append(overall_stats(baseline_df[baseline_df['tag'] == "3M\n21%"], "Rep\n3Mon."))
        #todo_dfs.append(overall_stats(baseline_df[baseline_df['tag'] == "Volley\n2.5%"], "Volley\nATD"))


    df = pd.concat(todo_dfs)
    df.to_csv('overall_stats_new.csv', index=False)

    # One figure per cloud computation target value.
    for c in df['c'].unique():
        if front:
            df_c = df[df['c'] == c].set_index('tag').loc[["No Rep\n(Spotify)", "Rep 3Mon.\n(Twitter)", "Yugong\n(Alibaba)", "Moirai\n(Our)"]]  # Ensure order
        elif job:
            df_c = df[df['c'] == c].set_index('tag').loc[["Volley\nRepTop2.5%", "Moi-\nJobDist", "Moirai"]]
        elif pr:
            df_c = df[df['c'] == c].set_index('tag').loc[["Moi\nJobDist", "Moi\nPR0.1%", "Moi\nPR0.2%", "Moi\nPR0.4%", "Moi\n1%Job", "Moi\n5%Job"]] # "Volley\nATD",
        else:
            df_c = df[df['c'] == c].set_index('tag').loc[["No\nRep", "Volley", "Rep\n3Mon.", "Rep\nTop2.5%", "Yugong", "Moirai"]]  # Ensure order "Volley\nATD",
        print(df_c)

        # Create subplots: the front comparison has a single cost panel,
        # all other variants get traffic (left) + cost (right) panels.
        if front:
            fig, ax1 = plt.subplots(1, 1, figsize=(6, 4.5), constrained_layout=True)
        else:
            fig, (ax2, ax1) = plt.subplots(1, 2, figsize=(11, 4), constrained_layout=True)

        # ---- PLOT 1: Cost Breakdown ---- #
        df_costs = df_c[['egress_cost', 'rep_cost', 'network_cost']]
        df_costs.plot(kind='bar', stacked=True, ax=ax1, color=[
            replication_cost_color, egress_cost_color, network_cost_color])

        # Apply hatch patterns to the bars for grayscale readability.
        for bar, hatch in zip(ax1.containers, hatch_patterns[:len(df_costs.columns)]):
            for patch in bar.patches:
                patch.set_hatch(hatch)

        # Add total sum as a single number on top of each stacked bar.
        for idx, rects in enumerate(zip(*ax1.containers)):
            total_height = sum(rect.get_height() for rect in rects)
            if total_height > 0:
                ax1.text(rects[0].get_x() + rects[0].get_width() / 2, total_height,
                         f'{total_height / 1000:.0f}K' if total_height < 1000000 else f'{total_height / 1000**2:.1f}M',
                         ha='center', va='bottom', fontsize=font_size - 2, color='black')
        if not front and not pr and not job:
            # Error bars (cost std-dev across weeks) for baseline systems only.
            for idx, tag in enumerate(["No\nRep", "Volley", "Rep\n3Mon.", "Rep\nTop2.5%", "Volley\nRepTop2.5%"]):
                if tag in df_costs.index:
                    total_cost = df_c.loc[tag, "total_cost"]
                    total_cost_std = df_c.loc[tag, "total_cost_std"]
                    # Fix: place the error bar at the bar's actual plotted
                    # position, not the index into the hard-coded list above.
                    ax1.errorbar(x=df_costs.index.get_loc(tag), y=total_cost, yerr=total_cost_std, color='black', capsize=5,
                                label="Std Dev" if idx == 0 else "")

        ax1.set_ylabel("Weekly Cost ($)", fontsize=font_size)
        ax1.set_xlabel(None)
        ax1.tick_params(rotation=0, labelsize=font_size - 2)
        if not front:
            ax1.set_xticklabels(df_costs.index, fontsize=font_size - 5, rotation=15)
        else:
            ax1.set_xticklabels(df_costs.index, fontsize=font_size - 3, rotation=0)

        # The sweep variants operate at ~10x smaller cost scale.
        if pr or job:
            yticks = [0, 20 * 1000, 40 * 1000, 60 * 1000, 80 * 1000, 100 * 1000, 120 * 1000]
            ytick_labels = ["0", "20K", "40K", "60K", "80K", "100K", "120K"]
        else:
            yticks = [0, 300 * 1000, 600 * 1000, 900 * 1000, 1200 * 1000, 1500 * 1000]
            ytick_labels = ["0", "300K", "600K", "900K", "1200K", "1500K"]

        ax1.set_yticks(yticks)
        ax1.set_yticklabels(ytick_labels, fontsize=font_size - 2)
        # Show the legend only once (c == 30 figure, or the single front panel).
        if c == 30 or front:
            ax1.legend(["Egress", "Replication", "Network"], fontsize=font_size - 2, ncol=1) # , loc='upper center'
        else:
            ax1.get_legend().remove()
        ax1.grid(axis='y')

        # ---- PLOT 2: Traffic Breakdown ---- #
        if not front:
            df_traffic = df_c[
                ['ingress_volume', 'egress_volume']]
            df_traffic.plot(kind='bar', stacked=True, ax=ax2, color=[
                ingress_traffic_color, egress_traffic_color])

            # Apply hatch patterns to the traffic bars (continue after the
            # patterns consumed by the cost panel).
            for bar, hatch in zip(ax2.containers, hatch_patterns[len(df_costs.columns):]):
                for patch in bar.patches:
                    patch.set_hatch(hatch)

            # Add total sum on top of each bar; values are in TiB, so
            # >1024 is rendered in PB.
            for idx, rects in enumerate(zip(*ax2.containers)):
                total_height = sum(rect.get_height() for rect in rects)
                if total_height > 1024:
                    ax2.text(rects[0].get_x() + rects[0].get_width() / 2, total_height,
                             f'{total_height / 1024:.1f}PB', ha='center', va='bottom', fontsize=font_size - 6,
                             color='black')
                else:
                    ax2.text(rects[0].get_x() + rects[0].get_width() / 2, total_height,
                             f'{total_height:.0f}TB', ha='center', va='bottom', fontsize=font_size - 6,
                             color='black')

            ax2.set_ylabel("Weekly Traffic", fontsize=font_size)
            ax2.set_xlabel(None)
            ax2.tick_params(rotation=0, labelsize=font_size - 2)
            ax2.set_xticklabels(df_costs.index, fontsize=font_size - 5, rotation=15)
            if pr or job:
                yticks = [0, 2 * 1024, 4 * 1024, 6 * 1024, 8 * 1024]
                ytick_labels = ["0", "2PB", "4PB", "6PB", "8PB"]
            else:
                yticks = [0, 30 * 1024, 60 * 1024, 90 * 1024, 120 * 1024]
                ytick_labels = ["0", "30PB", "60PB", "90PB", "120PB"]

            ax2.set_yticks(yticks)
            ax2.set_yticklabels(ytick_labels, fontsize=font_size - 2)
            if c == 30:
                ax2.legend(["Ingress Volume", "Egress Volume"], fontsize=font_size - 3, ncol=1)
            else:
                ax2.get_legend().remove()
            ax2.grid(axis='y')

        # title
        # if not front:
        #     fig.suptitle(f"On-premises:Cloud {100 - c}%:{c}%", fontsize=font_size + 2)

        # Save the figure
        plt.savefig(f'overall_comparison_c_{c}{suffix}.pdf')
        plt.close()
        print(f"Saved overall_comparison_c_{c}{suffix}.pdf")

def draw_job_routing():
    """Box-plot weekly traffic volume and egress cost per routing mode.

    Compares Yugong against four Moirai routing modes ('independent',
    'size-unaware', 'size-predict', 'size-aware'/oracular) for three
    on-prem:cloud splits (30/50/70% cloud), saving 'routing_traffic.pdf'
    and 'routing_cost.pdf'.
    """
    def process(df: pd.DataFrame, tag: str):
        """Derive egress volume/cost and total traffic, and tag the rows.

        Returns only the columns needed for the box plots.
        """
        # header: period,mode,cloud_computation_ratio,cloud_computation_target,
        # ingress_byte_Presto,egress_byte_Presto,ingress_byte_Spark,egress_byte_Spark,
        # P90_traffic_bps,P95_traffic_bps,P99_traffic_bps,
        # movement_ingress_bytes,movement_egress_bytes,rep_bytes,sample_rate
        df['egress_volume'] = (df['egress_byte_Presto'] + df['egress_byte_Spark'] +
                                 df['movement_egress_bytes'])
        # $0.02 per GiB of egress.
        df['egress_cost'] = df['egress_volume'] / 1024 ** 3 * 0.02
        df['traffic_volume'] = (df['ingress_byte_Presto'] + df['ingress_byte_Spark'] +
                                df['egress_volume'] + df['movement_ingress_bytes'])
        df['tag'] = tag
        df = df[['tag', 'cloud_computation_target', 'traffic_volume', 'egress_cost', 'mode']]

        return df

    # One color per box: Yugong, then the four Moirai modes.
    colors = ['darkorange', 'blue', 'dodgerblue', 'cyan']
    medianprops = dict(linestyle='-', linewidth=2, color='gold')

    # header: mode,cloud_computation_target,traffic_volume,egress_cost
    df_moirai = process(pd.read_csv('../sample_1.000_rep0.002/log.csv'), "Moirai")
    df_yugong = process(pd.read_csv('../yugong_results_rep0.002/log.csv'), "Yugong")

    # fig(a): weekly traffic volume
    fig, axes = plt.subplots(1, 3, figsize=(12, 5.5), sharey=True)
    for idx, c in enumerate([30, 50, 70]):
        ax = axes[idx]
        df_c = pd.concat([df_moirai[df_moirai['cloud_computation_target'] == c],
                          df_yugong[df_yugong['cloud_computation_target'] == c]])

        # box_data[0] is Yugong; [1..4] are the Moirai modes in order.
        box_data = []
        box_data.append(df_c[df_c['tag'] == 'Yugong']['traffic_volume']) # Yugong
        for mode in ['independent', 'size-unaware','size-predict', 'size-aware']:
            box_data.append(df_c[(df_c['mode'] == mode) & (df_c['tag'] == 'Moirai')]['traffic_volume'])

        # Yugong at x=1; the four Moirai boxes clustered around x=2.7.
        positions = [1, 2.2, 2.5, 2.9, 3.3]
        bp = ax.boxplot(box_data, patch_artist=True, positions=positions, widths=0.2, showfliers=False,
                        showmeans=True, whis=[10, 90], medianprops=medianprops)

        # Set boxplot colors
        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)

        # Set title and labels
        ax.set_title(f"On-prem:Cloud\n{100 - c}%:{c}%", fontsize=font_size-2)
        ax.set_xlabel(None)
        ax.set_yscale('log', base=2)
        ax.set_xticks([1, 2.7])
        ax.set_xticklabels(['Yugong', 'Moirai'], fontsize=font_size - 2)
        # Tick values are bytes (2**40 = 1 TiB); labels switch to PB at 2**10 TiB.
        ax.set_yticks([2** 40 * 2 ** i for i in [6, 7, 8, 9, 10, 13, 15, 16, 17]])
        ax.set_yticklabels([f"{2 ** i:.0f}TB" if i < 10 else f"{2 ** i / 1024:.0f}PB"
                                   for i in [6, 7, 8, 9, 10, 13, 15, 16, 17]], fontsize=font_size - 2)
        # Network capacity threshold line (11.5 PB/week).
        ax.axhline(y=11.5 * 1024**5, color='red', linestyle='--', linewidth=1.4)
        if c == 30:
            ax.text(0.06, 10.5 * 1024**5, '11.5PB', ha='center', fontsize=font_size - 2, color='red', rotation=0)
            ax.text(2.5, 13 * 1024**5, 'Network threshold', ha='center', fontsize=font_size - 6, color='red', rotation=0)
            ax.set_ylabel("Weekly Traffic Volume (log)")

        ax.tick_params(axis='x', labelsize=font_size - 2)
        ax.grid(axis='y')

        # Annotate each Moirai box, offset relative to its P90/P40 whisker.
        ax.text(2.2, np.percentile(box_data[1], 90) * 1.07, 'Indep', ha='center', fontsize=font_size - 6, color='black')
        ax.text(1.9, np.percentile(box_data[2], 90) * 0.6, 'Size\nUnaware', ha='center', fontsize=font_size - 6,)

        ax.text(2.95, np.percentile(box_data[3], 90) * 1.1, 'Size\nPredict', ha='center', fontsize=font_size - 6, color='black')
        ax.text(3.3, np.percentile(box_data[4], 40) * 0.3, 'Size\nOracular', ha='center', fontsize=font_size - 6, color='black')
    plt.tight_layout()
    plt.savefig('routing_traffic.pdf')
    plt.close()

    # fig(b): weekly egress cost (same layout as fig(a), different metric)
    fig, axes = plt.subplots(1, 3, figsize=(12, 5.5), sharey=True)
    for idx, c in enumerate([30, 50, 70]):
        ax = axes[idx]
        df_c = pd.concat([df_moirai[df_moirai['cloud_computation_target'] == c],
                            df_yugong[df_yugong['cloud_computation_target'] == c]])
        box_data = []
        box_data.append(df_c[df_c['tag'] == 'Yugong']['egress_cost'])  # Yugong
        for mode in ['independent', 'size-unaware', 'size-predict', 'size-aware']:
            box_data.append(df_c[(df_c['mode'] == mode) & (df_c['tag'] == 'Moirai')]['egress_cost'])

        positions = [1, 2.2, 2.5, 2.9, 3.3]
        bp = ax.boxplot(box_data, patch_artist=True, positions=positions, widths=0.2, showfliers=False,
                        showmeans=True, whis=[10, 90], medianprops=medianprops)

        # Set boxplot colors
        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)

        # Set title and labels
        ax.set_title(f"On-prem:Cloud\n{100 - c}%:{c}%", fontsize=font_size - 2)
        ax.set_xlabel(None)
        ax.set_yscale('log', base=2)
        ax.set_xticks([1, 2.7])
        ax.set_xticklabels(['Yugong', 'Moirai'], fontsize=font_size - 2)
        ax.set_yticks([200, 500, 2000, 5000, 50000, 500000, 5000000])
        ax.set_yticklabels(["$200", "$500", "$2K", "$5K", "$50K", "$500K", "$5M"],
                           fontsize=font_size - 2)
        if c == 30:
            ax.set_ylabel("Weekly Egress Cost (log)")

        ax.tick_params(axis='x', labelsize=font_size - 2)
        ax.grid(axis='y')

        ax.text(2.2, np.percentile(box_data[1], 90) * 1.07, 'Indep', ha='center', fontsize=font_size - 6, color='black')
        ax.text(1.9, np.percentile(box_data[2], 40), 'Size\nUnaware', ha='center', fontsize=font_size - 6, )

        ax.text(2.95, np.percentile(box_data[3], 90) * 1.1, 'Size\nPredict', ha='center', fontsize=font_size - 6,
                color='black')
        ax.text(3.3, np.percentile(box_data[4], 40) * 0.3, 'Size\nOracular', ha='center', fontsize=font_size - 6,
                color='black')
    plt.tight_layout()
    plt.savefig('routing_cost.pdf')
    plt.close()


def draw_traffic_rate(single=True):
    """Plot weekly average and P90/P95/P99 traffic rates over time.

    For Moirai and Yugong, draws one line chart per percentile metric
    (traffic rate on the left axis, the implied weekly network cost on a
    twin right axis) and saves 'traffic_rate_<metric>.pdf'.

    Args:
        single (bool): If True, plot only the 50% cloud split in one
            panel; otherwise draw three panels (30/50/70%).
    """
    def traffic_rate_stats(df: pd.DataFrame, tag):
        """Reduce a raw log to weekly rate stats (Gbps) for one system."""
        # header: period,mode,cloud_computation_ratio,cloud_computation_target,
        # ingress_byte_Presto,egress_byte_Presto,ingress_byte_Spark,egress_byte_Spark,
        # P90_traffic_bps,P95_traffic_bps,P99_traffic_bps,
        # movement_ingress_bytes,movement_egress_bytes,rep_bytes,sample_rate
        df = df[df['mode'] == 'size-predict'].copy()
        df['traffic_bytes'] = (df['ingress_byte_Presto'] + df['ingress_byte_Spark'] +
                                 df['egress_byte_Presto'] + df['egress_byte_Spark'] +
                                    df['movement_ingress_bytes'] + df['movement_egress_bytes'])
        # Bytes per week -> average bits per second.
        df['avg_traffic_bps'] = df['traffic_bytes'] * 8 / 7 / 24 / 3600

        # Extract start date from period (YYYYMMDD format)
        df['start_date'] = df['period'].str[:8]  # Extract YYYYMMDD
        df['start_date'] = pd.to_datetime(df['start_date'], format='%Y%m%d')  # Convert to datetime

        # Calculate week_id based on 2024-10-22 as the reference date
        reference_date = datetime(2024, 10, 22)
        df['week_id'] = ((df['start_date'] - reference_date).dt.days // 7 + 1).astype(int)  # Compute week index

        df['tag'] = tag
        df.rename(columns={'P90_traffic_bps': 'P90', 'P95_traffic_bps': 'P95', 'P99_traffic_bps': 'P99'}, inplace=True)
        # Convert all rates from bps to Gbps (1024**3).
        df['P90'] = df['P90'] / 1024 ** 3
        df['P95'] = df['P95'] / 1024 ** 3
        df['P99'] = df['P99'] / 1024 ** 3
        df['avg_traffic_bps'] = df['avg_traffic_bps'] / 1024 ** 3

        return df[['tag', 'week_id', 'cloud_computation_target', 'avg_traffic_bps',
                   'P90', 'P95', 'P99']]


    # Fixed color/marker per system so lines are comparable across figures.
    colors = {
        "Yugong": colors_default[0],
        "Moirai": colors_default[1]
    }

    markers = {
        'Yugong': 's',
        'Moirai': '*'
    }

    todo_dfs = []
    #for rate in [0.001, 0.002, 0.004]:
    for rate in [0.002]:
        df = pd.read_csv(f'../sample_1.000_rep{rate:.3f}/log.csv')
        todo_dfs.append(traffic_rate_stats(df, f"Moirai"))
    for rep_rate in [0.002]:
        df = pd.read_csv(f'../yugong_results_rep{rep_rate:.3f}/log.csv')
        todo_dfs.append(traffic_rate_stats(df, f"Yugong"))

    # Concatenate all processed data
    # header: period,cloud_computation_target,avg_traffic_bps,P90_traffic_bps,P95_traffic_bps,P99_traffic_bps
    df = pd.concat(todo_dfs)
    df.to_csv('traffic_rate_stats.csv', index=False)

    for metric in ['P90', 'P95', 'P99']:
        # Higher percentiles need more y-axis headroom (Gbps).
        if metric == 'P90':
            ylim = 600
        elif metric == 'P95':
            ylim = 900
        else:
            ylim = 1500
        if single:
            fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
            c_list = [50]
        else:
            # Create a figure with 3 subplots (1 row, 3 columns)
            fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(22, 7))
            c_list = [30, 50, 70]

        for idx, c in enumerate(c_list):
            df_c = df[df['cloud_computation_target'] == c]
            if single:
                ax = axes
            else:
                ax = axes[idx]

            for tag in ['Yugong', 'Moirai']:
                sub_df = df_c[df_c['tag'] == tag].copy()
                sub_df.sort_values(by='week_id', inplace=True)

                # Plot the PXX traffic and the average traffic for comparison
                ax.plot(sub_df['week_id'], sub_df['avg_traffic_bps'], linestyle='--',
                        label=f'{tag} Avg', color=colors[tag],
                        marker=markers[tag], markersize=8)
                ax.plot(sub_df['week_id'], sub_df[metric], linestyle='-',
                        label=f'{tag} {metric}', color=colors[tag],
                        marker=markers[tag], markersize=8)
            # Set title and labels
            if not single:
                ax.set_title(f"On-prem:Cloud={100 - c}%:{c}%")
            ax.set_xlabel('Week')
            # 160 Gbps network capacity threshold.
            ax.axhline(y=160, color='red', linestyle='--', linewidth=2)
            if idx == 0:
                ax.set_ylabel('Traffic Rate (Gbps)')
                ax.legend(fontsize=font_size - 3, ncol=2)
                ax.text(0.65, 120, '160', color='red', ha='center', fontsize=20)
                ax.text(5, 110, 'Network threshold', color='red', ha='center', fontsize=20)
            else:
                ax.set_ylabel(None)
                ax.set_yticklabels([])

            ax.set_ylim(0, ylim)
            ax.grid(axis='y')

            # Create a secondary y-axis for cost
            ax2 = ax.twinx()

            # Cost axis mirrors the rate axis: Gbps / 100 Gbps-link * $23.3/h * 7d * 24h.
            ax2.set_ylim(0, ylim / 100 * 7 * 24 * 23.3)
            for tag in ['Yugong', 'Moirai']:
                sub_df = df_c[df_c['tag'] == tag].copy()
                sub_df.sort_values(by='week_id', inplace=True)
                ax2.plot(sub_df['week_id'], sub_df[metric] / 100 * 7 * 24 * 23.3,
                         linestyle='-', label=f'{tag} Cost', color=colors[tag],
                         marker=markers[tag], markersize=8)
            if idx == len(c_list)-1:
                yticks = ax2.get_yticks()
                ytick_labels = [f"{int(i/1000)}K" for i in yticks]
                ax2.set_yticklabels(ytick_labels)
                ax2.set_ylabel('Weekly Network Cost ($)')
            else:
                ax2.set_yticklabels([])


        # Save the figure
        plt.tight_layout()
        plt.savefig(f'traffic_rate_{metric}.pdf', bbox_inches='tight')
        plt.close()
        print(f"Saved traffic_rate_{metric}.pdf")

def replication_effects():
    """Print how replicated-table lists affect job/edge coverage.

    Loads per-table job-volume reports for Presto and Spark, then for
    several replication rates and strategies reports how many table-access
    edges, jobs and tables are (un)affected by the replicated-table set.
    Results are printed, not returned.
    """
    # header: abstractFingerPrint,db_name,table_name,inputDataSize,cputime,outputDataSize
    # presto_df = pd.read_csv('../newTraces/report-abFP-volume-table-20241022-20241028-Presto.csv')
    # presto_df['db_table'] = presto_df['db_name'] + '.' + presto_df['table_name']
    # spark_df = pd.read_csv('../newTraces/report-abFP-volume-table-20241022-20241028-Spark.csv')
    # spark_df['db_table'] = spark_df['db_name'] + '.' + spark_df['table_name']
    presto_df = pd.read_csv('../newTraces/report-abFP-volume-table-20250114-20250120-Presto.csv')
    presto_df['db_table'] = presto_df['db_name'] + '.' + presto_df['table_name']
    spark_df = pd.read_csv('../newTraces/report-abFP-volume-table-20250114-20250120-Spark.csv')
    spark_df['db_table'] = spark_df['db_name'] + '.' + spark_df['table_name']

    # "Effective" = accesses to tables NOT covered by the replica set.
    for rep_rate in [0.02, 0.002]:
        print(f'Rep: {rep_rate}')
        for strategy in [
            # 'read_traffic_volume','inverse_dataset_size',
            #              'job_access_frequency',
            #                 'read_traffic_density',
                         'job_access_density'
                         ]:
            path = f"../sample_1.000_rep{rep_rate:.3f}_strategies/replicated_tables_{str(rep_rate)}_{strategy}.csv"
            if not os.path.exists(path):
                continue
            rep_list = pd.read_csv(path)[f'replicated_tables'].to_list()
            effective_presto_df = presto_df[~presto_df['db_table'].isin(rep_list)]
            effective_spark_df = spark_df[~spark_df['db_table'].isin(rep_list)]

            print(f"Strategy: {strategy}")
            #print("Presto # of edges, all:", len(presto_df), "effective:", len(presto_df) - len(reduced_presto_df))
            #print("Spark # of edges, all:", len(spark_df), "effective:", len(spark_df) - len(reduced_spark_df))
            print("# effective edges", len(effective_spark_df) + len(effective_presto_df))
            print("# of effective jobs", effective_spark_df['abstractFingerPrint'].nunique() + effective_presto_df['abstractFingerPrint'].nunique())
            print("# of unique db_tables", pd.concat([effective_spark_df, effective_presto_df])['db_table'].nunique())



    # "Affected" = accesses to tables that ARE in the replica set.
    for rep_rate in [0.001, 0.002, 0.004]:
        rep_list = pd.read_csv(f"../sample_1.000_rep{rep_rate:.3f}/replicated_tables.csv")['replicated_tables'].to_list()
        reduced_presto_df = presto_df[presto_df['db_table'].isin(rep_list)]
        reduced_spark_df = spark_df[spark_df['db_table'].isin(rep_list)]
        print(f"Replication rate: {rep_rate:.3f}")
        print("Presto # of edges, all:", len(presto_df), "affected:", len(reduced_presto_df))
        print("Spark # of edges, all:", len(spark_df), "affected:", len(reduced_spark_df))

    # Same "affected" accounting for the Yugong trace.
    # header: abstractFingerPrint,db_name,table_name,inputDataSize,outputDataSize,cputime
    yugong_df = pd.read_csv('../yugongTraces/report-uown-volume-table-20241022-20241028.csv')
    yugong_df['db_table'] = yugong_df['db_name'] + '.' + yugong_df['table_name']
    for rep_rate in [0.004]:
        rep_list = pd.read_csv(f"../yugong_results_rep{rep_rate:.3f}/replicated_tables_0.004.csv")['replicated_tables'].to_list()
        reduced_yugong_df = yugong_df[yugong_df['db_table'].isin(rep_list)]
        print(f"Replication rate: {rep_rate:.3f}")
        print("Yugong # of edges, all:", len(yugong_df), "affected:", len(reduced_yugong_df))

def verify_traffic_rate(yugong: bool = False):
    """Sanity-check one week of per-minute traffic rates for c=30.

    Sums the four ingress/egress bps columns across the week's seven daily
    traffic files, prints the implied total weekly volume, then re-reads
    the files to print the P90/P95/P99 of the per-minute rate samples.

    Args:
        yugong (bool): Read from the Yugong results directory instead of
            the Moirai sample directory.
    """
    traffic_rate = 0
    # Week under inspection: 8 weeks after 2024-10-29.
    start_date = datetime(year=2024, month=10, day=29) + timedelta(days=7*8)
    print("Start date:", start_date)
    for date in date_range(start=start_date, end=start_date + timedelta(days=6), freq='D'):
        if yugong:
            df = pd.read_csv(f'../yugong_results_rep0.002/c30/traffic_{date.strftime("%Y%m%d")}.csv')
        else:
            df = pd.read_csv(f'../sample_1.000_rep0.001/c30/traffic_{date.strftime("%Y%m%d")}.csv')
        df['traffic_rate'] = df['egress_rate_presto_bps'] + df['egress_rate_spark_bps'] + \
                                df['ingress_rate_presto_bps'] + df['ingress_rate_spark_bps']
        traffic_rate += df['traffic_rate'].sum()
    # Sum of per-minute bps samples -> bytes: /8 (bits->bytes) * 60
    # (seconds per one-minute bucket).  NOTE(review): assumes each row is a
    # one-minute sample — confirm against the traffic file producer.
    weekly_traffic = traffic_rate / 8 * 60
    print("Weekly traffic:",
          human_readable_size(weekly_traffic))

    # Second pass: collect the raw per-minute samples for percentiles.
    all_traffic_rates = []
    for single_date in pd.date_range(start=start_date, end=start_date + timedelta(days=6), freq='D'):
        if yugong:
            traffic_file = os.path.join("../yugong_results_rep0.002/c30", f"traffic_{single_date.strftime('%Y%m%d')}.csv")
        else:
            traffic_file = os.path.join("../sample_1.000_rep0.001/c30", f"traffic_{single_date.strftime('%Y%m%d')}.csv")
        if os.path.exists(traffic_file):
            df = pd.read_csv(traffic_file)
            df['egress_rate_bps'] = df['egress_rate_presto_bps'] + df['egress_rate_spark_bps']
            df['ingress_rate_bps'] = df['ingress_rate_presto_bps'] + df['ingress_rate_spark_bps']
            df['traffic_rate'] = df['egress_rate_bps'] + df['ingress_rate_bps']
            all_traffic_rates.extend(df["traffic_rate"].tolist())

            # all_traffic_rates.extend(df["egress_rate_bps"].tolist())
            # all_traffic_rates.extend(df["ingress_rate_bps"].tolist())
        else:
            print(f"Traffic file not found: {traffic_file}")
    print("# traffic rates:", len(all_traffic_rates))
    print("P90", int(np.percentile(all_traffic_rates, 90)),
          "P95", int(np.percentile(all_traffic_rates, 95)),
            "P99", int(np.percentile(all_traffic_rates, 99)),)

    # # Convert to NumPy array and sort for CDF
    # all_traffic_rates = np.array(all_traffic_rates)
    # sorted_data = np.sort(all_traffic_rates)
    # cdf = np.arange(1, len(sorted_data) + 1) / len(sorted_data)
    #
    # # Plot CDF
    # plt.figure(figsize=(8, 6))
    # plt.plot(sorted_data, cdf, marker='.', linestyle='none')
    # plt.xlabel('Traffic Rate (bps)')
    # plt.ylabel('CDF')
    # plt.title('CDF of Traffic Rates')
    # plt.grid(True)
    # plt.xscale('log')
    # plt.savefig('traffic_rate_cdf.png')

def plot_weekly_traffic(week_id: int, yugong: bool = False, c: int = 30):
    """Plot the per-minute total traffic rate over one simulated week.

    Week 2 begins on 2024-10-29; other weeks shift by whole weeks from
    that base.  The seven daily traffic CSVs are read in order, the four
    ingress/egress rate columns are summed per row, and the concatenated
    series is saved as 'traffic_week_<week_id>.png'.
    """
    # Compute start date of the given week (week_id=2 starts on 2024-10-29).
    week_start = datetime(2024, 10, 29) + timedelta(days=7 * (week_id - 2))
    print("Start date:", week_start)

    # Directory holding the daily traffic files for the chosen system/split.
    if yugong:
        day_dir = f"../yugong_results_rep0.002/c{c}"
    else:
        day_dir = f"../sample_1.000_rep0.001/c{c}"

    samples = []
    # Walk the 7 days of the week, skipping (but reporting) missing files.
    for day_offset in range(7):
        day = week_start + timedelta(days=day_offset)
        traffic_file = f"{day_dir}/traffic_{day.strftime('%Y%m%d')}.csv"
        if not os.path.exists(traffic_file):
            print(f"Traffic file not found: {traffic_file}")
            continue
        daily = pd.read_csv(traffic_file)
        # Total rate = ingress + egress for both Presto and Spark.
        per_row_rate = (daily['egress_rate_presto_bps'] + daily['egress_rate_spark_bps'] +
                        daily['ingress_rate_presto_bps'] + daily['ingress_rate_spark_bps'])
        samples.extend(per_row_rate.tolist())

    # Plot the traffic rate over the week
    plt.figure(figsize=(10, 5))
    plt.plot(samples, linestyle='-', color='blue')
    plt.xlabel("Minute bucket")
    plt.ylabel("Traffic Rate (bps)")
    plt.ylim(0, 2*10**12)
    plt.title(f"Traffic Rate Over Week {week_id}")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(f'traffic_week_{week_id}.png')

def draw_growth(format="png"):
    """Plot growth trends of daily jobs, templates, and traffic volume.

    Reads ``../metrics_per_day_presto.csv`` and ``../metrics_per_day_spark.csv``,
    fits least-squares trend lines to each series, and saves three annotated
    figures: ``daily_jobs.<format>``, ``daily_templates.<format>``, and
    ``daily_traffic_volume.<format>``.

    Args:
        format (str): Output image extension, e.g. "png" or "pdf". (Name kept
            for caller compatibility even though it shadows the builtin.)
    """
    def compute_manual_linear_slope(x, y):
        """Computes the slope of the best-fit line using least squares regression."""
        mean_x, mean_y = np.mean(x), np.mean(y)
        numerator = np.sum((x - mean_x) * (y - mean_y))
        denominator = np.sum((x - mean_x) ** 2)
        # Degenerate case (all x identical): report a flat trend instead of dividing by zero.
        slope = numerator / denominator if denominator != 0 else 0
        return slope

    # Function to compute and plot linear fit manually
    def plot_manual_linear_fit(x, y, label, ax, color):
        """Computes and plots a manual least squares linear fit."""
        # Compute slope and intercept manually
        slope = compute_manual_linear_slope(x, y)
        intercept = np.mean(y) - slope * np.mean(x)

        # Generate fitted line over the observed x values
        linear_fit = slope * x + intercept

        ax.plot(x, linear_fit, linestyle="-", label=label, color=color)
        print(label, "slope:", slope)

    # Function to format tick labels with T/B/M/K suffixes
    def format_ticks(value, _):
        if value >= 1e12:
            return f"{value / 1e12:.1f}T"
        elif value >= 1e9:
            return f"{value / 1e9:.1f}B"
        elif value >= 1e6:
            return f"{value / 1e6:.1f}M"
        elif value >= 1e3:
            return f"{value / 1e3:.0f}K"
        else:
            return str(int(value))

    # Load Presto and Spark per-day metrics
    presto_df = pd.read_csv("../metrics_per_day_presto.csv", parse_dates=['date'])
    spark_df = pd.read_csv("../metrics_per_day_spark.csv", parse_dates=['date'])

    # Sort by date
    presto_df.sort_values('date', inplace=True)
    spark_df.sort_values('date', inplace=True)

    # Compute days elapsed since the earliest date across both engines
    min_date = min(presto_df['date'].min(), spark_df['date'].min())
    presto_df['days_elapsed'] = (presto_df['date'] - min_date).dt.days
    spark_df['days_elapsed'] = (spark_df['date'] - min_date).dt.days

    # Plot 1: Number of daily jobs
    plt.figure(figsize=(10, 4))
    plt.tick_params(axis='both', labelsize=font_size - 2)
    plt.plot(presto_df['days_elapsed'], presto_df['daily_jobs'], label="Presto", color='blue', linestyle='--')
    plot_manual_linear_fit(presto_df['days_elapsed'].values, presto_df['daily_jobs'].values, "Trend line", plt.gca(),
                           'blue')
    plt.plot(spark_df['days_elapsed'], spark_df['daily_jobs'], label="Spark", color='orange', linestyle='--')
    plot_manual_linear_fit(spark_df['days_elapsed'].values, spark_df['daily_jobs'].values, "Trend line", plt.gca(), 'orange')
    # Hand-placed annotation arrows pointing from the text to each trend line
    plt.text(50, 280*1000, "30% annual increase", fontsize=font_size - 2, color='red')
    plt.annotate('', xy=(58, 395 * 1000), xytext=(55, 310 * 1000),
                 arrowprops=dict(arrowstyle="->", color='red', lw=2))

    plt.annotate('', xy=(60, 240 * 1000), xytext=(55, 275 * 1000),
                 arrowprops=dict(arrowstyle="->", color='red', lw=2))
    plt.xlabel("Day", fontsize=font_size)
    plt.ylabel("# of Daily Jobs", fontsize=font_size)
    plt.legend(fontsize=font_size - 3, ncol=4, bbox_to_anchor=(0.5, 1.2), loc='upper center')

    plt.gca().yaxis.set_major_formatter(FuncFormatter(format_ticks))
    plt.ylim(bottom=0, top=600 * 1000)
    plt.xlim(0, 110)
    plt.grid()
    plt.tight_layout()
    plt.subplots_adjust(top=0.8)  # Adjust to leave space for the legend above
    plt.savefig(f"daily_jobs.{format}", bbox_inches='tight')


    # Plot 2: Number of daily templates
    plt.figure(figsize=(10, 4))
    plt.tick_params(axis='both', labelsize=font_size - 2)
    plt.plot(presto_df['days_elapsed'], presto_df['daily_templates'], label="Presto", color='blue', linestyle='--')
    plt.plot(spark_df['days_elapsed'], spark_df['daily_templates'], label="Spark", color='orange', linestyle='--')
    plot_manual_linear_fit(presto_df['days_elapsed'].values, presto_df['daily_templates'].values, "Trend line", plt.gca(), 'blue')
    plot_manual_linear_fit(spark_df['days_elapsed'].values, spark_df['daily_templates'].values, "Trend line", plt.gca(), 'orange')
    plt.text(50, 50*1000, "20% annual increase", fontsize=font_size - 2, color='red')
    plt.annotate('', xy=(58, 105 * 1000), xytext=(55, 60 * 1000),
                 arrowprops=dict(arrowstyle="->", color='red', lw=2))

    plt.annotate('', xy=(60, 35 * 1000), xytext=(55, 46 * 1000),
                 arrowprops=dict(arrowstyle="->", color='red', lw=2))

    plt.xlabel("Day", fontsize=font_size)
    plt.ylabel("# of Daily Templates", fontsize=font_size)
    plt.legend(fontsize=font_size - 3, ncol=4, bbox_to_anchor=(0.5, 1.2), loc='upper center')
    plt.gca().yaxis.set_major_formatter(FuncFormatter(format_ticks))
    plt.ylim(bottom=0, top=150 * 1000)
    plt.xlim(0, 110)
    plt.grid()
    plt.tight_layout()
    plt.savefig(f"daily_templates.{format}", bbox_inches='tight')

    # Plot 3: Daily traffic volume (read/write), converted from bytes to PB (1024**5)
    plt.figure(figsize=(10, 3.5))
    plt.tick_params(axis='both', labelsize=font_size - 2)
    plt.plot(presto_df['days_elapsed'], presto_df['daily_read_volume'] / 1024 ** 5, label="Presto", color='blue', linestyle='--')
    plot_manual_linear_fit(presto_df['days_elapsed'].values, presto_df['daily_read_volume'].values / 1024 ** 5, "Trend line", plt.gca(), 'blue')
    plt.plot(spark_df['days_elapsed'], (spark_df['daily_read_volume'] + spark_df['daily_write_volume']) / 1024 ** 5,
             label="Spark", color='orange', linestyle='--')
    plot_manual_linear_fit(spark_df['days_elapsed'].values, (spark_df['daily_read_volume'] + spark_df['daily_write_volume']).values / 1024 ** 5, "Trend line", plt.gca(), 'orange')

    plt.text(50, 55, "30% annual increase", fontsize=font_size - 2, color='red')
    plt.annotate('', xy=(58, 105), xytext=(55, 65),
                 arrowprops=dict(arrowstyle="->", color='red', lw=2))

    plt.annotate('', xy=(60, 25), xytext=(55, 50),
                 arrowprops=dict(arrowstyle="->", color='red', lw=2))

    plt.xlabel("Day", fontsize=font_size)
    plt.ylabel("Daily Traffic (PB)", fontsize=font_size)
    plt.gca().yaxis.set_major_formatter(FuncFormatter(format_ticks))
    plt.ylim(bottom=0, top=150)
    plt.xlim(0, 110)
    plt.grid()
    plt.tight_layout()
    plt.savefig(f"daily_traffic_volume.{format}", bbox_inches='tight')

def draw_PR_heuristics(double=True):
    """Compare replication heuristics on optimization time and problem size.

    Draws a bar chart of per-heuristic optimization time alongside the number
    of edges/jobs/tables (all in thousands), either as two stacked subplots
    (``double=True``) or a single twin-axis figure, and saves it as
    ``optimization_time_<double>.pdf``.

    Args:
        double (bool): If True, use two stacked subplots; otherwise a single
            plot with a secondary y-axis.
    """
    # Heuristic labels (x-axis categories); Moirai's heuristic is last.
    heuristics = [
        "No\nRep", "Access Traffic\nVolume", "Inverse\nDataset Size",
        "Job Access\nFrequency", "Access Traffic\nDensity", "Job Access\nDensity (Moirai)"
    ]

    # Averages over three runs, approximated from the results table (hours).
    replication_times = [(93+254+94)/3, (120+181+15)/3, (20+111+11)/3, (11+87+34)/3, (55+119+48)/3, (2+3+1)/3]  # Approximate from table (h)
    num_edges = [1252, 1102, 855, 763, 819, 509]  # In K
    num_jobs = [356, 350, 307, 322, 330, 256]  # In K
    num_tables = [134, 133, 113, 133, 133, 119]  # In K

    x = np.arange(len(heuristics))  # the label locations

    if double:
        width = 0.25
        fig, axes = plt.subplots(2, 1, figsize=(15, 6), sharex=True)
        ax1, ax2 = axes

        ax1.bar(x, replication_times, width, color='skyblue')
        ax1.set_ylabel("Optimization\nTime (hr)")
        ax1.set_ylim(0, 168)
        ax1.grid(axis='y')

        ax2.bar(x - width, num_edges, width, label="# Edges", color='salmon', hatch='/')
        ax2.bar(x, num_jobs, width, label="# Jobs", color='mediumseagreen', hatch='\\')
        ax2.bar(x + width, num_tables, width, label="# Tables", color='mediumpurple', hatch='|')

        ax2.set_ylabel("Count (K)")
        ax2.set_ylim(0, 1500)
        ax2.set_xticks(x)
        ax2.set_xticklabels(heuristics)
        ax2.grid(axis='y', linestyle='--', alpha=0.6)
        ax2.legend()
    else:
        width = 0.18
        fig, ax1 = plt.subplots(figsize=(14, 4.5))
        ax2 = ax1.twinx()

        # Adjust positions so all 4 bars are shown for each heuristic
        ax1.bar(x - 1.7 * width, replication_times, width, label="Optimization Time", color=colors_default[0])
        ax1.legend(loc='upper left', fontsize=font_size + 3)

        # Label each time bar; the first two bars get an extra left shift so
        # the text clears the neighboring bars.
        for i, time in enumerate(replication_times):
            ax1.text(
                x[i] - 2.4 * width if (i == 0 or i == 1) else x[i] - 2 * width, time + 3,  # position slightly above the bar
                f"{round(time)}hr", ha='center', va='bottom', fontsize=font_size + 1
            )

        ax2.bar(x - 0.4 * width, num_edges, width, label="# Edges", color=colors_default[1], hatch='/')
        ax2.bar(x + 0.6 * width, num_jobs, width, label="# Jobs", color=colors_default[2], hatch='\\')
        ax2.bar(x + 1.6 * width, num_tables, width, label="# Tables", color=colors_default[3], hatch='|')
        ax2.legend(loc='upper right', fontsize=font_size+3)

        ax1.set_ylabel("Optimization\nTime (hr)", fontsize=font_size + 2)
        ax1.set_ylim(0, 250)
        # Pin the tick locations before relabeling to avoid matplotlib's
        # "FixedFormatter should only be used together with FixedLocator" warning.
        yticks = ax1.get_yticks()
        ax1.set_yticks(yticks)
        ax1.set_yticklabels([int(tick) for tick in yticks], fontsize=font_size + 2)

        ax2.set_ylabel("Count (K)", fontsize=font_size + 2)
        ax2.set_ylim(0, 1500)

        ax1.set_xticks(x)
        ax1.set_xticklabels(heuristics, rotation=10)
        ax2.grid(axis='y', linestyle='--', alpha=0.6)
        # Same locator-pinning for the secondary axis (was set_yticklabels alone).
        ax2_ticks = ax2.get_yticks()
        ax2.set_yticks(ax2_ticks)
        ax2.set_yticklabels([int(tick) for tick in ax2_ticks], fontsize=font_size + 2)

    plt.tight_layout()
    plt.savefig(f"optimization_time_{double}.pdf", bbox_inches='tight')


def draw_edges_cdf():
    """Plot CDFs of distinct tables accessed per job and per template.

    Reads one week of per-day job traces and the week-level template
    (abstract fingerprint) reports for Presto and Spark, computes sampled
    CDFs, writes percentile summaries to ``cdf_percentiles.csv``, and saves
    ``degree_cdf_job.pdf`` and ``degree_cdf_template.pdf``.

    NOTE(review): the code that writes the ``cdf_results.csv`` cache is
    commented out below, so the cached-load branch only runs if that file
    was produced by some earlier version — confirm before relying on it.
    """
    def calculate_percentiles(data, percentiles):
        # Guard: np.percentile raises on empty input, so return NaNs instead.
        if len(data) > 0:
            return np.percentile(data, percentiles)
        else:
            return [np.nan] * len(percentiles)

    def sample_cdf(data, num_points=1000):
        """Compute CDF and sample it at `num_points` evenly spaced intervals."""
        if len(data) == 0:
            return np.array([]), np.array([])

        sorted_data = np.sort(data)
        cdf = np.arange(1, len(sorted_data) + 1) / len(sorted_data)

        # Select `num_points` evenly spaced indices
        # (if data is shorter than num_points, indices repeat — harmless for plotting).
        sample_indices = np.linspace(0, len(sorted_data) - 1, num_points, dtype=int)

        return sorted_data[sample_indices], cdf[sample_indices]

    def calculate_cdf(data):
        """Calculate CDF from data."""
        # NOTE(review): currently unused; sample_cdf is used everywhere below.
        sorted_data = np.sort(data)
        cdf = np.arange(1, len(sorted_data) + 1) / len(sorted_data)
        return sorted_data, cdf

    cdf_cache_file = "cdf_results.csv"

    if os.path.exists(cdf_cache_file):
        # Fast path: reuse previously computed CDF samples.
        print("Loading cached CDF results...")
        cdf_results = pd.read_csv(cdf_cache_file)
        # Columns are padded with NaN to equal length when saved; dropna restores them.
        presto_x = cdf_results['presto_x'].dropna().values
        presto_cdf = cdf_results['presto_cdf'].dropna().values
        spark_x = cdf_results['spark_x'].dropna().values
        spark_cdf = cdf_results['spark_cdf'].dropna().values
        table_presto_x = cdf_results['table_presto_x'].dropna().values
        table_presto_cdf = cdf_results['table_presto_cdf'].dropna().values
        table_spark_x = cdf_results['table_spark_x'].dropna().values
        table_spark_cdf = cdf_results['table_spark_cdf'].dropna().values
    else:
        print("Computing CDFs...")
        # One fixed analysis week.
        start_date = datetime(2024, 10, 22)
        end_date = datetime(2024, 10, 28)

        # Per-job counts of distinct (db, table) names touched.
        job_presto_counts = []
        job_spark_counts = []

        for date in date_range(start=start_date, end=end_date, freq='D'):
            print("Processing", date.strftime("%Y-%m-%d"))
            presto_path = f"../jobTraces/{date.strftime('%Y%m%d')}-Presto.csv"
            spark_path = f"../jobTraces/{date.strftime('%Y%m%d')}-Spark.csv"

            if os.path.exists(presto_path):
                presto_df = pd.read_csv(presto_path)
                # NOTE(review): this is nunique(db_name) + nunique(table_name)
                # per job, not a joint distinct count of (db, table) pairs —
                # confirm this matches the intended "# of tables" metric.
                job_presto_counts.extend(presto_df.groupby('job_id')[['db_name', 'table_name']].nunique().sum(axis=1))
            else:
                print(f"Missing file: {presto_path}")

            if os.path.exists(spark_path):
                spark_df = pd.read_csv(spark_path)
                job_spark_counts.extend(spark_df.groupby('job_id')[['db_name', 'table_name']].nunique().sum(axis=1))
            else:
                print(f"Missing file: {spark_path}")

        # Per-template (abstract fingerprint) counts over the whole week.
        table_presto_counts = []
        table_spark_counts = []

        presto_path = f"../newTraces/report-abFP-volume-table-{start_date.strftime('%Y%m%d')}-{end_date.strftime('%Y%m%d')}-Presto.csv"
        spark_path = f"../newTraces/report-abFP-volume-table-{start_date.strftime('%Y%m%d')}-{end_date.strftime('%Y%m%d')}-Spark.csv"

        if os.path.exists(presto_path):
            presto_df = pd.read_csv(presto_path)
            table_presto_counts.extend(
                presto_df.groupby('abstractFingerPrint')[['db_name', 'table_name']].nunique().sum(axis=1))
        else:
            print(f"Missing file: {presto_path}")

        if os.path.exists(spark_path):
            spark_df = pd.read_csv(spark_path)
            table_spark_counts.extend(
                spark_df.groupby('abstractFingerPrint')[['db_name', 'table_name']].nunique().sum(axis=1))
        else:
            print(f"Missing file: {spark_path}")

        # Compute CDFs (sampled to at most 1000 points each)
        presto_x, presto_cdf = sample_cdf(job_presto_counts)
        spark_x, spark_cdf = sample_cdf(job_spark_counts)
        table_presto_x, table_presto_cdf = sample_cdf(table_presto_counts)
        table_spark_x, table_spark_cdf = sample_cdf(table_spark_counts)

        # Define percentiles to compute (P10 to P100 in steps of 5)
        percentiles = np.arange(10, 101, 5)

        # Percentiles are computed on the sampled x-values, not the raw counts.
        distributions = {
            "Presto Job": presto_x,
            "Spark Job": spark_x,
            "Presto Template": table_presto_x,
            "Spark Template": table_spark_x
        }

        # Compute percentiles
        percentile_results = {
            dist_name: calculate_percentiles(data, percentiles)
            for dist_name, data in distributions.items()
        }

        # Convert to DataFrame for display
        percentile_df = pd.DataFrame(percentile_results, index=[f"P{p}" for p in percentiles])
        percentile_df.to_csv("cdf_percentiles.csv", index=True)


        # # Save CDF results
        # cdf_df = pd.DataFrame({
        #     'presto_x': np.pad(presto_x, (0, max(0, len(table_presto_x) - len(presto_x))), 'constant', constant_values=np.nan),
        #     'presto_cdf': np.pad(presto_cdf, (0, max(0, len(table_presto_cdf) - len(presto_cdf))), 'constant', constant_values=np.nan),
        #     'spark_x': np.pad(spark_x, (0, max(0, len(table_spark_x) - len(spark_x))), 'constant', constant_values=np.nan),
        #     'spark_cdf': np.pad(spark_cdf, (0, max(0, len(table_spark_cdf) - len(spark_cdf))), 'constant', constant_values=np.nan),
        #     'table_presto_x': table_presto_x,
        #     'table_presto_cdf': table_presto_cdf,
        #     'table_spark_x': table_spark_x,
        #     'table_spark_cdf': table_spark_cdf
        # })
        # cdf_df.to_csv(cdf_cache_file, index=False)
        # print("CDF results saved.")

    # Calculate mean values (means of the sampled CDF x-values)
    mean_presto_jobs = np.mean(presto_x) if len(presto_x) > 0 else 0
    mean_spark_jobs = np.mean(spark_x) if len(spark_x) > 0 else 0
    mean_presto_tables = np.mean(table_presto_x) if len(table_presto_x) > 0 else 0
    mean_spark_tables = np.mean(table_spark_x) if len(table_spark_x) > 0 else 0

    print(f"Mean # of Tables per Presto Job: {mean_presto_jobs:.2f}")
    print(f"Mean # of Tables per Spark Job: {mean_spark_jobs:.2f}")
    print(f"Mean # of Tables per Presto Template: {mean_presto_tables:.2f}")
    print(f"Mean # of Tables per Spark Template: {mean_spark_tables:.2f}")

    # Plot job-level CDF
    plt.figure(figsize=(8, 5))
    plt.plot(presto_x, presto_cdf, label="Presto", linestyle='-', marker='.')
    plt.plot(spark_x, spark_cdf, label="Spark", linestyle='-', marker='.')
    plt.xscale("log")  # Set x-axis to log scale
    plt.xlabel("# of Tables per Job")
    plt.ylabel("Fraction of jobs (CDF)")
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.savefig("degree_cdf_job.pdf")

    # Plot table-level CDF
    plt.figure(figsize=(8, 5))
    plt.plot(table_presto_x, table_presto_cdf, label="Presto", linestyle='-', marker='.')
    plt.plot(table_spark_x, table_spark_cdf, label="Spark", linestyle='-', marker='.')
    plt.xscale("log")  # Set x-axis to log scale
    plt.xlabel("# of Tables per Template")
    plt.ylabel("Fraction of jobs (CDF)")
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.savefig("degree_cdf_template.pdf")

def draw_reorg():
    """Plot cumulative ingress/egress traffic of data redistribution.

    Compares the redistribution-cost-unaware placement against Moirai as
    on-premises storage shrinks, plus a best-case ingress line, and saves
    the figure as ``migration.pdf``.
    """
    dfs = {
        "Redistribution cost unaware": pd.read_csv("../long_term/placement_unaware.csv"),
        "Moirai": pd.read_csv("../long_term/placement_moirai.csv"),
    }

    fig, ax = plt.subplots(figsize=(11, 5))

    # Line styles consumed in plot order: dashed for the unaware baseline,
    # solid for Moirai (black-and-white friendly).
    line_styles = ['--', '--', '-', '-']

    for label, df in dfs.items():
        # Convert human-readable Ingress/Egress sizes to bytes and accumulate.
        df['Ingress_Bytes'] = df['Ingress'].apply(parse_size).cumsum()
        df['Egress_Bytes'] = df['Egress'].apply(parse_size).cumsum()

        # Plot the Ingress and Egress data
        ax.plot(df['On-premises'], df['Ingress_Bytes'], label=f'{label} (ingress)',
                linestyle=line_styles.pop(0), color='blue', linewidth=2, marker='x', markersize=12)
        ax.plot(df['On-premises'], df['Egress_Bytes'], label=f'{label} (egress)',
                linestyle=line_styles.pop(0), color='green', linewidth=2, marker='*', markersize=12)

        if label == "Redistribution cost unaware":
            # Best case: ingress equals the fraction of data that no longer
            # fits on-prem. 299.12 PB is presumably the total dataset size —
            # TODO confirm against the placement inputs.
            best_case_x = df['On-premises']
            best_case_y = (1 - df['On-premises'] / 100) * (299.12 * 1024 ** 5)
            ax.plot(best_case_x, best_case_y, 'r-', label='Best case (ingress)', linewidth=6)

    # Customize the plot
    ax.set_xlabel('On-premises Storage Space / Total Data Size (%)', fontsize=font_size)
    ax.set_ylabel('Traffic Volume (PB)', fontsize=font_size)
    ax.tick_params(axis='x', labelsize=font_size - 2)
    ax.tick_params(axis='y', labelsize=font_size - 2)
    ax.grid(True)

    # Add legend above the axes
    ax.legend(loc='upper center',
              bbox_to_anchor=(0.5, 1.35), fontsize=font_size - 2,
              ncol=2, frameon=False)

    # Y-axis: ticks every 10 PB, labels only on every other tick
    ticks = [x for x in range(0, 55, 5)]
    yticks = [i * 10 * 1024 ** 5 for i in ticks]
    ytick_labels = ["0"] + [f"{i * 10}" if i % 2 == 0 else "" for i in ticks[1:]]
    ax.set_ylim(ymax=450 * 1024 ** 5)
    ax.set_yticks(yticks)
    ax.set_yticklabels(ytick_labels, fontsize=font_size - 2)

    ax.set_xlim(0, 90)
    ax.set_xticks([90, 80, 70, 60, 50, 40, 30, 20, 10, 0])
    ax.set_xticklabels(["90%", "80%", "70%", "60%", "50%", "40%", "30%", "20%", "10%", "0%"], fontsize=font_size - 2)

    # Invert x-axis so storage budget shrinks from left to right
    ax.invert_xaxis()

    plt.tight_layout()
    plt.savefig('migration.pdf', bbox_inches='tight')

if __name__ == '__main__':
    # Headline figures: replication effects, traffic rates, job routing,
    # and long-term growth trends (saved as PDFs for the paper).
    replication_effects()
    draw_traffic_rate(single=False)
    draw_job_routing()
    draw_growth(format="pdf")

    # Overall-comparison figure in its four variants.
    draw_overall_new(front=True)
    draw_overall_new(job=True)
    draw_overall_new()
    draw_overall_new(pr=True)

    # Heuristic comparison (single twin-axis layout), degree CDFs, and
    # redistribution-cost figure.
    draw_PR_heuristics(double=False)
    draw_edges_cdf()
    draw_reorg()

    """ was not used in submission, for debugging """
    # verify_traffic_rate(yugong=True)
    # for week_id in range(9, 14):
    #     plot_weekly_traffic(week_id=week_id, yugong=True)
    # pass


