import argparse
import concurrent.futures
import fcntl
import logging
import os
import re
import shutil
from collections import defaultdict, OrderedDict
from datetime import datetime, timedelta
from enum import Enum

import numpy as np
import pandas as pd

from Yugong.Ownership import Ownership
from utility import human_readable_size, to_seconds, parse_size

# Define command-line arguments (parsed at import time, below)
parser = argparse.ArgumentParser(description="Test optimization with upcoming jobs")
# --c is a percentage; callers divide by 100 to get the target cloud compute ratio
parser.add_argument("--c", type=int, default=30, help="Portion of compute to cloud")
parser.add_argument("--num_week", type=int, default=1, help="# of weeks for evaluation")
parser.add_argument("--opt_path", type=str, default="sample_1.000", help="Optimization results stored under this path")
parser.add_argument("--debug", action="store_true", help="debug mode might just run part of the traces")
parser.add_argument("--yugong", action="store_true", help="Yugong mode: project-based job placement")
parser.add_argument("--simple", action="store_true", help="omit traffic rate calculation")
parser.add_argument("--policy", type=str, default="size-predict", help="scheduling policy",
                    choices=["size-predict", "size-aware", "size-unaware", "independent"])

args = parser.parse_args()  # NOTE: runs on import; importing this module requires compatible sys.argv

# Engine type tag carried in each trace row's 'type' column; built with the
# enum functional API (equivalent to the class-based declaration).
JobType = Enum("JobType", {"SPARK": "spark", "PRESTO": "presto"})

class Stat:
    """Accumulator for a group of queries: count plus total cputime and
    input/output bytes. Read the attributes directly for the totals."""

    def __init__(self, help=None):
        # `help` is the human-readable label printed by print(); the
        # builtin-shadowing parameter name is kept for caller compatibility.
        self.count = 0
        self.cputime = 0
        self.inputDataSize = 0
        self.outputDataSize = 0
        self.help = help

    def add(self, cputime=0, inputDataSize=0, outputDataSize=0):
        """Record one query's resource usage."""
        self.count += 1
        self.cputime += cputime
        self.inputDataSize += inputDataSize
        self.outputDataSize += outputDataSize

    def get_cputime(self):
        """Return the accumulated cputime."""
        return self.cputime

    # Bug fix: the former count()/inputDataSize()/outputDataSize() getter
    # methods were removed. They were shadowed by the same-named instance
    # attributes assigned in __init__, so e.g. stat.count() raised
    # "'int' object is not callable" — no caller could ever use them.
    # Read the attributes directly instead.

    def print(self):
        """Log a one-line summary; zero-valued fields are omitted."""
        log_str = f"{self.help}: count {self.count}"
        if self.cputime > 0:
            log_str += f", cputime {self.cputime:.4g}"
        if self.inputDataSize > 0:
            log_str += f", inputDataSize {human_readable_size(self.inputDataSize)}"
        if self.outputDataSize > 0:
            log_str += f", outputDataSize {human_readable_size(self.outputDataSize)}"

        logging.info(log_str)

    def clean(self):
        """Reset all counters to zero (the label is kept)."""
        self.count = 0
        self.cputime = 0
        self.inputDataSize = 0
        self.outputDataSize = 0

class Scheduler:
    def __init__(self, dir_path, table_size_path, weight_lookup=None,
                 yugong=False, ownership=None):
        assert os.path.exists(dir_path), f"dir_path={dir_path} not found"
        self.dir_path = dir_path
        logging.info(f"dir_path: {dir_path}, yugong mode: {yugong}")

        self.yugong = yugong
        if self.yugong:
            assert ownership is not None, "ownership is None"
            self.ownership = ownership

        # load template and dataset placements
        # query_kv = self._load_query_placement()
        # self.query_map = Cache(maxsize=len(query_kv) // 100, kv_store=query_kv)
        # logging.info(f"# of query placements: {len(query_kv)}")
        self.query_map = self._load_query_placement()
        logging.info(f"# of query placements: {len(self.query_map)}")
        self.dataset_map = self._load_dataset_placement()
        logging.info(f"# of dataset placements: {len(self.dataset_map)}")

        # load table sizes
        logging.info(f"Loaded table size: {table_size_path}")

        self.table_size_map = self._load_table_sizes(table_size_path)

        self.size_lookup = None
        self.db_table_size = None

        self.weight_lookup = weight_lookup

        # historical workload
        self.stat_cloud_query = Stat("cloud queries")
        self.stat_on_prem_query = Stat("on-prem queries")

        self.stat_categories = {
            "both_sides": Stat("Old queries with tables on both sides"), # all_table_local & all_table_cloud == True
            "only_cloud": Stat("Old queries with all tables on cloud"), # all_table_local == False and all_table_cloud == True
            "only_onprem": Stat("Old queries with all tables on-prem"), # all_table_local == True and all_table_cloud == False
            "needs_transfer": Stat("Old queries requiring ingress/egress"), # all_table_local == False and all_table_cloud == False
        }

        self.stat_categories_new = {
            "both_sides": Stat("New queries with tables on both sides"),  # all_table_local & all_table_cloud == True
            "only_cloud": Stat("New queries with all tables on cloud"),
            # all_table_local == False and all_table_cloud == True
            "only_onprem": Stat("New queries with all tables on-prem"),
            # all_table_local == True and all_table_cloud == False
            "needs_transfer": Stat("New queries requiring ingress/egress"),
            # all_table_local == False and all_table_cloud == False
        }

    def get_cloud_computation_ratio(self) -> float:
        total_cputime = self.stat_cloud_query.get_cputime() + self.stat_on_prem_query.get_cputime()
        return self.stat_cloud_query.get_cputime() / total_cputime if total_cputime else 0.0

    def stats(self):
            self.stat_cloud_query.print()
            self.stat_on_prem_query.print()
            for stat in self.stat_categories.values():
                stat.print()
            for stat in self.stat_categories_new.values():
                stat.print()

    """
    Determine query placement based on dataset distribution and policy.
    Input: table_volume_list, format: [(table_name, input_volume, output_volume), ...]
    Returns: (placement_y, remote traffic)
    """
    def place_query(self, template_id, cputime, table_volume_list,
                    policy='size-predict', target_cloud_cpu_ratio=None, info=None):
        placement_y = self.query_map.get(template_id, None)
        #if self.yugong:
            #print(f"template_id: {template_id}, cputime: {cputime}, table_volume_list: {table_volume_list}, placement_y: {placement_y}")
            #assert placement_y is not None, f"project name {template_id} not found in query_map"

        all_tables_local, all_tables_cloud = True, True
        table_zw_map = {} # group-aware, i.e., identify the location of a table even if packed in ".group"

        for table, _, _ in table_volume_list:
            db_name, _, table_name = table.partition('.')
            group_name = self.ownership.get_table_ownership(table) if self.yugong else db_name

            if table not in self.table_size_map and table not in self.dataset_map:
                # can assume this table is small enough
                continue
            elif table not in self.dataset_map and f"{group_name}.group" not in self.dataset_map:
                logging.warning(f"table {table} and {group_name}.group not found in dataset_map")
                continue
            elif table not in self.dataset_map:
                on_prem, cloud = self.dataset_map[f"{group_name}.group"]
            else:
                on_prem, cloud = self.dataset_map[table]

            # dataset_key = table if table in self.dataset_map else f"{group_name}.group"
            # if dataset_key not in self.dataset_map:
            #     logging.warning(f"Table {table} and group {dataset_key} not found in dataset_map")
            #     continue
            #
            # on_prem, cloud = self.dataset_map[dataset_key]

            table_zw_map[table] = (on_prem, cloud)

            all_tables_local &= on_prem == 0
            all_tables_cloud &= cloud == 0

        #print(f"table_zw_map: {table_zw_map}", flush=True) # debug
        # Compute cloud ratio
        cloud_ratio = self.get_cloud_computation_ratio()

        # debugging
        # if policy == 'size-predict':
        #     placement_y = None

        input_volume = sum(volume for _, volume, _ in table_volume_list)
        output_volume = sum(volume for _, _, volume in table_volume_list)
        # total_volume = sum(volume for _, volume in table_volume_list)
        if placement_y is not None:
            category = "both_sides" if all_tables_local and all_tables_cloud else \
                "only_cloud" if all_tables_local else \
                    "only_onprem" if all_tables_local else \
                        "needs_transfer"
            self.stat_categories[category].add(cputime=cputime, inputDataSize=input_volume, outputDataSize=output_volume)

            # Adjust placement decision if needed
            if not self.yugong:
                if placement_y == 0 and all_tables_cloud and cloud_ratio < target_cloud_cpu_ratio:
                    placement_y = 1
                elif placement_y == 1 and all_tables_local and cloud_ratio > target_cloud_cpu_ratio:
                    placement_y = 0
        else: # New query classification
            category = "both_sides" if all_tables_local and all_tables_cloud else \
                "only_cloud" if all_tables_local else \
                    "only_onprem" if all_tables_local else \
                        "needs_transfer"
            self.stat_categories_new[category].add(cputime=cputime, inputDataSize=input_volume, outputDataSize=output_volume)

            if policy == "independent":
                placement_y = 1 if cloud_ratio < target_cloud_cpu_ratio else 0
            elif policy in ['size-predict', 'size-aware', 'size-unaware']:
                if all_tables_local and all_tables_cloud:
                    placement_y = 1 if cloud_ratio < target_cloud_cpu_ratio else 0
                elif all_tables_cloud:
                    placement_y = 1
                elif all_tables_local:
                    placement_y = 0
                # TODO: remove the magic number 0.05
                elif cloud_ratio < target_cloud_cpu_ratio - 0.05:
                    placement_y = 1
                elif cloud_ratio > target_cloud_cpu_ratio + 0.05:
                    placement_y = 0
                else:
                    traffic_if_executed_cloud = 0
                    traffic_if_executed_on_prem = 0
                    for table, input_access, output_access in table_volume_list:
                        if table not in table_zw_map:
                            continue # work-around: this table should be small and cold

                        if policy == 'size-predict':
                            weight = self.weight_lookup.get(table, 1) # set to 1 Byte, effectively omitted this table as it should be cold
                        elif policy == 'size-aware':
                            weight = input_access+output_access
                        else:  # size-unaware
                            weight = 1
                        if table_zw_map[table][0] == 1:
                            traffic_if_executed_on_prem += weight
                        if table_zw_map[table][1] == 1:
                            traffic_if_executed_cloud += weight

                    placement_y = 1 if traffic_if_executed_cloud < traffic_if_executed_on_prem else 0
            else:
                raise ValueError(f"Unknown policy: {policy}")

        # Update stats
        (self.stat_cloud_query if placement_y == 1 else self.stat_on_prem_query).add(cputime=cputime,
                                                                                     inputDataSize=input_volume,
                                                                                     outputDataSize=output_volume)

        egress = 0
        ingress = 0
        if placement_y == 0: # job executed on-prem
            for table, input_access, output_access in table_volume_list:
                if table not in table_zw_map:
                    continue
                if table_zw_map[table][0] == 1: # but data cannot be found on-prem, egress
                    egress += input_access
                    ingress += output_access
                    # if volume > 0:
                    #     if self.yugong is False:
                    #         logging.info(f"st {info} fp {abFP} create egress {human_readable_size(volume)} to {table}")
        else:  # cloud
            assert placement_y == 1 # job executed on cloud
            for table, input_access, output_access in table_volume_list:
                if table not in table_zw_map:
                    continue
                if table_zw_map[table][1] == 1:  # but data cannot be found on cloud
                    ingress += input_access
                    egress += output_access
                    # if volume > 0:
                    #     if self.yugong is False:
                    #        logging.info(f"st {info} fp {abFP} create ingress {human_readable_size(volume)} to {table}")
        return placement_y, egress, ingress

    def _load_query_placement(self):
        file_path = os.path.join(self.dir_path, 'query_placement.csv')
        if not os.path.exists(file_path):
            return {}

        df = pd.read_csv(file_path, delimiter=',', on_bad_lines='warn')
        df['abFP'] = df['abFP'].astype(str)
        df['y'] = df['y'].astype(int)
        return df.set_index('abFP')['y'].to_dict()

    def _load_dataset_placement(self):
        df = pd.read_csv(os.path.join(self.dir_path, 'dataset_placement.csv'),
                         dtype={'table': str, 'size': float})
        df['z'] = df['z'].astype(int)
        df['w'] = df['w'].astype(int)
        df['table'] = df['table'].astype(str)

        logging.info(
            f"datalake size according to dataset_map: {human_readable_size(df['size'].sum() * 1024**3)}")
        # dict_of_placement = df.set_index('table')[['z', 'w']].to_dict(orient='index')
        # Ensure we return actual z and w values
        dataset_map = {row['table']: (row['z'], row['w']) for _, row in df.iterrows()}

        return dataset_map

    def _load_table_sizes(self, table_size_path):
        """Load table sizes and filter out zero-size entries."""
        df = pd.read_csv(table_size_path)
        df['hive_database_name'] = df['hive_database_name'].astype(str)
        df['hive_table_name'] = df['hive_table_name'].astype(str)

        df = df[df['dir_size'] > 0]  # Exclude tables with zero size
        df['full_table_name'] = df['hive_database_name'] + '.' + df['hive_table_name']
        return df.set_index('full_table_name')['dir_size'].to_dict()

def process_baseline(baseline: str, dir_path: str, num_of_week: int, c: int,
                     rep_rate: float,
                     traffic_rate_disabled: bool = False,
                     ):
    """Replay the job traces week by week under a baseline placement.

    Args:
        baseline: baseline name; for "rep_x_month" the per-job egress/ingress
            bytes are scaled by `rep_rate`.
        dir_path: directory holding the placement CSVs; the routing log and
            flushed traffic buckets are written here as well.
        num_of_week: number of evaluation weeks (after a one-week warm-up).
        c: cloud compute target in percent (passed to the scheduler as c / 100).
        rep_rate: replication multiplier for the "rep_x_month" baseline.
        traffic_rate_disabled: when True, skip per-minute traffic bookkeeping.
    """
    period_day = 7
    policy = "size-predict"
    setup_logger(os.path.join(dir_path, 'routing.txt'))
    period_start = datetime.strptime("2024-10-22", "%Y-%m-%d")
    logging.info(f"Preparing the first df starting from {period_start}")

    # Warm-up week: seed the per-table weight lookup (mean totalDataSize).
    # Trace header: start_time: str,job_id,template_id,duration,
    # uown_names,inputDataSize,outputDataSize,cputime, type
    df_presto = pd.concat([read_Presto(period_start + timedelta(days=i)) for i in range(period_day)])
    df_spark = pd.concat([read_Spark(period_start + timedelta(days=i)) for i in range(period_day)])
    df = pd.concat([df_spark, df_presto])
    df['totalDataSize'] = df['inputDataSize'] + df['outputDataSize']
    weight_group = df.groupby(['table']).agg(
        totalDataSize=('totalDataSize', 'mean')).reset_index()
    weight_lookup = weight_group.set_index('table').to_dict()['totalDataSize']

    logging.info(f"# of jobs: {len(df['job_id'].unique())}")

    period_start = period_start + timedelta(days=period_day)

    # Per-minute traffic buckets; OrderedDict keeps minute order for easy popping.
    minute_buckets = OrderedDict()

    # Statistics per period, logged in a single batch at the end.
    period_logs = []

    for period_offset in range(num_of_week):
        start_date = period_start + timedelta(days=period_offset * period_day)
        df_presto = pd.concat([read_Presto(start_date + timedelta(days=i)) for i in range(period_day)])
        df_spark = pd.concat([read_Spark(start_date + timedelta(days=i)) for i in range(period_day)])
        df = pd.concat([df_spark, df_presto])

        df['totalDataSize'] = df['inputDataSize'] + df['outputDataSize']
        df = df.sort_values(['start_time', 'job_id'])
        logging.info(f"Week {period_offset + 1}, starting on {start_date}")
        logging.info(f"# of jobs: {len(df['job_id'].unique())}")

        jobs = df.groupby(['start_time', 'job_id'])

        scheduler = Scheduler(dir_path=os.path.join(dir_path),
                              table_size_path='report-table-size-20241021.csv',
                              weight_lookup=weight_lookup)  # TODO: stateful between periods

        egress_byte_Presto = 0
        ingress_byte_Presto = 0
        egress_byte_Spark = 0
        ingress_byte_Spark = 0

        # enumerate jobs
        for (start_time, job_id), group in jobs:
            job_type = group['type'].iloc[0]
            # Spark jobs: take the first row's cputime; others: sum all rows.
            if job_type == JobType.SPARK:
                cputime = group['cputime'].iloc[0]
            else:
                cputime = group['cputime'].sum()
            template_id = group['template_id'].iloc[0]
            table_volume_list = [(row['table'], row['inputDataSize'], row['outputDataSize'])
                                 for _, row in group.iterrows()]

            placement_y, egress_byte, ingress_byte = scheduler.place_query(
                template_id, cputime, table_volume_list,
                policy=policy,
                target_cloud_cpu_ratio=c / 100,
                info=start_time)
            if baseline == "rep_x_month":
                egress_byte *= rep_rate
                ingress_byte *= rep_rate
            if job_type == JobType.SPARK:
                egress_byte_Spark += egress_byte
                ingress_byte_Spark += ingress_byte
            else:
                egress_byte_Presto += egress_byte
                ingress_byte_Presto += ingress_byte

            if not traffic_rate_disabled:
                # Spread this job's traffic evenly over the minutes it spans.
                duration = group['duration'].iloc[0]
                if job_type == JobType.SPARK:
                    tStart = datetime.strptime(start_time, "%Y-%m-%d %H:%M:%S.%f")
                else:
                    tStart = datetime.strptime(start_time, "%Y-%m-%d %H:%M:%S")
                tEnd = tStart + timedelta(seconds=duration)
                start_minute = tStart.replace(second=0, microsecond=0)
                end_minute = (tEnd + timedelta(seconds=59)).replace(second=0, microsecond=0)

                # Flush expired minute buckets (older than job_start_minute)
                # TODO: we can not flush as 'start_time' (str) is not the correct index
                # flush_oldest_minute_buckets(minute_buckets, start_minute, os.path.join(dir_path, f"c{c}"))

                minute = start_minute  # renamed from `min` to avoid shadowing the builtin
                total_minute = (end_minute - start_minute).total_seconds() / 60
                while minute < end_minute:
                    if minute not in minute_buckets:
                        minute_buckets[minute] = {'egress_byte_Presto': 0, 'ingress_byte_Presto': 0,
                                                  'egress_byte_Spark': 0, 'ingress_byte_Spark': 0}
                    if job_type == JobType.SPARK:
                        minute_buckets[minute]['egress_byte_Spark'] += egress_byte / total_minute
                        minute_buckets[minute]['ingress_byte_Spark'] += ingress_byte / total_minute
                    else:
                        minute_buckets[minute]['egress_byte_Presto'] += egress_byte / total_minute
                        minute_buckets[minute]['ingress_byte_Presto'] += ingress_byte / total_minute
                    minute += timedelta(minutes=1)

        # Refresh the table weights with this week's observations.
        new_weight_group = df.groupby(['table']).agg(
            totalDataSize=('totalDataSize', 'mean')).reset_index()
        new_weight_lookup = new_weight_group.set_index('table').to_dict()['totalDataSize']
        weight_lookup.update(new_weight_lookup)

        logging.info(f"Egress {human_readable_size(egress_byte_Presto + egress_byte_Spark)}: "
                     f"Presto {human_readable_size(egress_byte_Presto)}, Spark {human_readable_size(egress_byte_Spark)}")
        logging.info(f"Ingress {human_readable_size(ingress_byte_Presto + ingress_byte_Spark)}: "
                     f"Presto {human_readable_size(ingress_byte_Presto)}, Spark {human_readable_size(ingress_byte_Spark)}")

        # Log period statistics
        period_logs.append({
            "start_date": start_date,
            "end_date": start_date + timedelta(days=period_day - 1),
            "scheduling_policy": policy,
            "c": c,
            "cloud_compute_ratio": scheduler.get_cloud_computation_ratio(),  # Store only the ratio
            "egress_byte_Presto": egress_byte_Presto,
            "ingress_byte_Presto": ingress_byte_Presto,
            "egress_byte_Spark": egress_byte_Spark,
            "ingress_byte_Spark": ingress_byte_Spark,
            "dir_path": dir_path,
            "opt_dir_path": None
        })

    if not traffic_rate_disabled:
        # Flush remaining minute buckets.
        # NOTE(review): process_jobs flushes into os.path.join(dir_path, f"c{c}")
        # while this writes directly into dir_path — confirm that is intended.
        flush_oldest_minute_buckets(minute_buckets, None, dir_path)

    # Now log all stored period statistics in a single batch
    for log_entry in period_logs:
        log_period_statistics(
            log_entry["start_date"],
            log_entry["end_date"],
            log_entry["scheduling_policy"],
            log_entry["c"],
            log_entry["cloud_compute_ratio"],  # Only store ratio instead of full scheduler object
            log_entry["egress_byte_Presto"],
            log_entry["ingress_byte_Presto"],
            log_entry["egress_byte_Spark"],
            log_entry["ingress_byte_Spark"],
            log_entry["dir_path"],
            log_entry["opt_dir_path"],
            traffic_rate_disabled=traffic_rate_disabled,
            rep_rate=rep_rate
        )


def process_jobs(c, num_of_week, dir_path, debug, policy, traffic_rate_disabled=False):
    """Replay job traces week by week against the weekly optimization results.

    Args:
        c: cloud compute target in percent; also names the c{c} output subdir
            (e.g. c30/c50/c70) where traffic rates are flushed.
        num_of_week: number of evaluation weeks (after a one-week warm-up).
        dir_path: parent directory holding the weekly test_run_* result dirs.
        debug: when True, only the first 3000 trace rows of each week are used.
        policy: scheduling policy passed to Scheduler.place_query.
        traffic_rate_disabled: when True, skip per-minute traffic bookkeeping.
    """
    # Recreate the c{c} directory under parent to store traffic_rate.
    if os.path.exists(os.path.join(dir_path, f"c{c}")) and not traffic_rate_disabled:
        shutil.rmtree(os.path.join(dir_path, f"c{c}"), ignore_errors=False)
    os.makedirs(f"{dir_path}/c{c}", exist_ok=True)
    period_day = 7

    setup_logger(os.path.join(dir_path, f'routing_c{c}_{policy}.txt'))
    logging.info(f"Start processing jobs with c={c}, num_of_week={num_of_week}, dir_path={dir_path}, debug={debug}, policy={policy}, traffic_rate_disabled={traffic_rate_disabled}")
    # Warm-up week: seed the per-table weight lookup (mean totalDataSize).
    period_start = datetime.strptime("2024-10-22", "%Y-%m-%d")
    logging.info(f"Preparing the first df starting from {period_start}")
    # Trace header: start_time: str,job_id,template_id,duration,
    # uown_names,inputDataSize,outputDataSize,cputime, type
    df_presto = pd.concat([read_Presto(period_start + timedelta(days=i)) for i in range(period_day)])
    df_spark = pd.concat([read_Spark(period_start + timedelta(days=i)) for i in range(period_day)])
    df = pd.concat([df_spark, df_presto])
    df['totalDataSize'] = df['inputDataSize'] + df['outputDataSize']
    weight_group = df.groupby(['table']).agg(
        totalDataSize=('totalDataSize', 'mean')).reset_index()
    weight_lookup = weight_group.set_index('table').to_dict()['totalDataSize']

    logging.info(f"# of jobs: {len(df['job_id'].unique())}")

    period_start = period_start + timedelta(days=period_day)

    # Per-minute traffic buckets; OrderedDict keeps minute order for easy popping.
    minute_buckets = OrderedDict()

    # Statistics per period, logged in a single batch at the end.
    period_logs = []

    for period_offset in range(num_of_week):
        start_date = period_start + timedelta(days=period_offset * period_day)
        # The optimizer output dir for week N>1 is suffixed with the previous
        # week's month/day; the first week's dir has no suffix.
        if period_offset == 0:
            label = ""
        else:
            label = "_" + (start_date - timedelta(days=period_day)).strftime("%m%d")

        # TODO: this can be parallelized
        df_presto = pd.concat([read_Presto(start_date + timedelta(days=i)) for i in range(period_day)])
        df_spark = pd.concat([read_Spark(start_date + timedelta(days=i)) for i in range(period_day)])

        df = pd.concat([df_spark, df_presto])
        df['totalDataSize'] = df['inputDataSize'] + df['outputDataSize']
        df = df.sort_values(['start_time', 'job_id'])
        print("first 5 jobs", df.head())
        logging.info(f"Week {period_offset + 1}, starting on {start_date}")
        logging.info(f"# of jobs: {len(df['job_id'].unique())}")

        if debug:
            print("debug mode: retain first 3K rows", flush=True)
            # retain first 3K rows
            jobs = df.head(3000).groupby(['start_time', 'job_id'])
            print(jobs.head(1))
        else:
            jobs = df.groupby(['start_time', 'job_id'])

        # Prepare scheduler with this week's optimization results; the table
        # size snapshot switches at 2024-05-13.
        scheduler = Scheduler(dir_path=os.path.join(dir_path, f"test_run_c{c}_bw0.02_local{100-c}{label}"),
                              table_size_path='report-table-size-0907.csv' if start_date < datetime.strptime("2024-05-13", "%Y-%m-%d") else 'report-table-size-20241021.csv',
                              weight_lookup=weight_lookup)  # TODO: stateful between periods

        egress_byte_Presto = 0
        ingress_byte_Presto = 0
        egress_byte_Spark = 0
        ingress_byte_Spark = 0

        # "hybrid" jobs are those that incur any cross-site traffic
        hybrid_job_count = 0
        hybrid_job_bytes = 0

        # enumerate jobs
        for (start_time, job_id), group in jobs:
            job_type = group['type'].iloc[0]
            # Spark jobs: take the first row's cputime; others: sum all rows.
            if job_type == JobType.SPARK:
                cputime = group['cputime'].iloc[0]
            else:
                cputime = group['cputime'].sum()
            template_id = group['template_id'].iloc[0]
            table_volume_list = [(row['table'], row['inputDataSize'], row['outputDataSize'])
                                 for _, row in group.iterrows()]

            placement_y, egress_byte, ingress_byte = scheduler.place_query(
                template_id, cputime, table_volume_list,
                policy=policy,
                target_cloud_cpu_ratio=c / 100,
                info=start_time)

            if job_type == JobType.SPARK:
                egress_byte_Spark += egress_byte
                ingress_byte_Spark += ingress_byte
            else:
                egress_byte_Presto += egress_byte
                ingress_byte_Presto += ingress_byte

            if egress_byte + ingress_byte > 0:
                hybrid_job_count += 1
                # renamed loop vars from input/output: don't shadow builtins
                hybrid_job_bytes += sum(in_bytes + out_bytes
                                        for _, in_bytes, out_bytes in table_volume_list)

            if not traffic_rate_disabled:
                # Spread this job's traffic evenly over the minutes it spans.
                duration = group['duration'].iloc[0]
                if job_type == JobType.SPARK:
                    tStart = datetime.strptime(start_time, "%Y-%m-%d %H:%M:%S.%f")
                else:
                    tStart = datetime.strptime(start_time, "%Y-%m-%d %H:%M:%S")
                tEnd = tStart + timedelta(seconds=duration)
                start_minute = tStart.replace(second=0, microsecond=0)
                end_minute = (tEnd + timedelta(seconds=59)).replace(second=0, microsecond=0)

                # Flush expired minute buckets (older than job_start_minute)
                # TODO: we can not flush as 'start_time' (str) is not the correct index
                # flush_oldest_minute_buckets(minute_buckets, start_minute, os.path.join(dir_path, f"c{c}"))

                minute = start_minute  # renamed from `min` to avoid shadowing the builtin
                total_minute = (end_minute - start_minute).total_seconds() / 60
                while minute < end_minute:
                    if minute not in minute_buckets:
                        minute_buckets[minute] = {'egress_byte_Presto': 0, 'ingress_byte_Presto': 0,
                                                  'egress_byte_Spark': 0, 'ingress_byte_Spark': 0}
                    if job_type == JobType.SPARK:
                        minute_buckets[minute]['egress_byte_Spark'] += egress_byte / total_minute
                        minute_buckets[minute]['ingress_byte_Spark'] += ingress_byte / total_minute
                    else:
                        minute_buckets[minute]['egress_byte_Presto'] += egress_byte / total_minute
                        minute_buckets[minute]['ingress_byte_Presto'] += ingress_byte / total_minute
                    minute += timedelta(minutes=1)

        # Refresh the table weights with this week's observations.
        new_weight_group = df.groupby(['table']).agg(
            totalDataSize=('totalDataSize', 'mean')).reset_index()
        new_weight_lookup = new_weight_group.set_index('table').to_dict()['totalDataSize']
        weight_lookup.update(new_weight_lookup)

        logging.info(f"Egress {human_readable_size(egress_byte_Presto + egress_byte_Spark)}: "
                     f"Presto {human_readable_size(egress_byte_Presto)}, Spark {human_readable_size(egress_byte_Spark)}")
        logging.info(f"Ingress {human_readable_size(ingress_byte_Presto + ingress_byte_Spark)}: "
                     f"Presto {human_readable_size(ingress_byte_Presto)}, Spark {human_readable_size(ingress_byte_Spark)}")
        logging.info(f"# of hybrid jobs: {hybrid_job_count} with access bytes: {human_readable_size(hybrid_job_bytes)}")

        # Log period statistics
        period_logs.append({
            "start_date": start_date,
            "end_date": start_date + timedelta(days=period_day - 1),
            "scheduling_policy": policy,
            "c": c,
            "cloud_compute_ratio": scheduler.get_cloud_computation_ratio(),  # Store only the ratio
            "egress_byte_Presto": egress_byte_Presto,
            "ingress_byte_Presto": ingress_byte_Presto,
            "egress_byte_Spark": egress_byte_Spark,
            "ingress_byte_Spark": ingress_byte_Spark,
            "dir_path": dir_path,
            "opt_dir_path": os.path.join(dir_path, f"test_run_c{c}_bw0.02_local{100-c}{label}")
        })

    if not traffic_rate_disabled:
        # Flush remaining minute buckets
        flush_oldest_minute_buckets(minute_buckets, None, os.path.join(dir_path, f"c{c}"))

    # Now log all stored period statistics in a single batch
    for log_entry in period_logs:
        log_period_statistics(
            log_entry["start_date"],
            log_entry["end_date"],
            log_entry["scheduling_policy"],
            log_entry["c"],
            log_entry["cloud_compute_ratio"],  # Only store ratio instead of full scheduler object
            log_entry["egress_byte_Presto"],
            log_entry["ingress_byte_Presto"],
            log_entry["egress_byte_Spark"],
            log_entry["ingress_byte_Spark"],
            log_entry["dir_path"],
            log_entry["opt_dir_path"],
            traffic_rate_disabled=traffic_rate_disabled
        )

def Moirai_weekly_cost_print(opt_path):
    """Print weekly dollar-cost estimates per cloud computation target.

    Reads <opt_path>/log.csv (written by log_period_statistics) and, for each
    distinct cloud_computation_target, reports:
      - mean weekly traffic volume (job ingress/egress + movement bytes)
      - provisioned bandwidth (max P95 traffic rate) and its cost
      - egress cost at $0.02/GiB and replication cost at $0.023/GiB-month
        (divided by 4 to approximate a weekly share)
    Prints a message and returns early when log.csv does not exist.
    """
    dir_path = opt_path
    if not os.path.exists(dir_path) or not os.path.exists(os.path.join(dir_path, "log.csv")):
        print("No log.csv found")
        return
    # header: period, mode, cloud_computation_ratio, cloud_computation_target,
    # ingress_byte_Presto,egress_byte_Presto,ingress_byte_Spark,egress_byte_Spark,
    # P90_traffic_bps, P95_traffic_bps, P99_traffic_bps,
    # movement_ingress_bytes, movement_egress_bytes, rep_bytes, sample_rate
    df = pd.read_csv(os.path.join(dir_path, "log.csv"))
    df['job_ingress_bytes'] = df['ingress_byte_Presto'] + df['ingress_byte_Spark']
    df['job_egress_bytes'] = df['egress_byte_Presto'] + df['egress_byte_Spark']
    df['traffic_volume'] = (df['job_ingress_bytes'] + df['job_egress_bytes']
                            + df['movement_ingress_bytes'] + df['movement_egress_bytes'])
    df['egress_volume'] = df['job_egress_bytes'] + df['movement_egress_bytes']
    for c in df['cloud_computation_target'].unique():
        df_c = df[df['cloud_computation_target'] == c]
        print(f"Cloud computation target: {c}")
        print(f"Weekly traffic volume: {human_readable_size(df_c['traffic_volume'].mean())}")
        print("Bandwidth:", df_c['P95_traffic_bps'].max())
        # Bug fix: use the per-target slice df_c, not the whole df, so the
        # network cost matches the target being reported (previously every
        # target was billed at the global max P95 rate).
        network_cost = df_c['P95_traffic_bps'].max() / (100*1024**3) * 24 * 7
        egress_cost = df_c['egress_volume'].mean() / 1024**3 * 0.02
        rep_cost = df_c['rep_bytes'].mean() / 1024**3 * 0.023 / 4
        print(f"Network cost: {network_cost:.0f}, Egress cost: {egress_cost:.0f}, Replication cost: {rep_cost:.0f}")

def extract_movement_rep_and_sample(log_file):
    """Scan an optimizer log.txt for movement bytes, replication size, and sample rate.

    Returns (movement_ingress_bytes, movement_egress_bytes, replication_size,
    sample_rate). Later matches in the file override earlier ones. A missing
    file yields (0, 0, None, None).
    """
    # Compile once up front; these patterns are applied to every line.
    movement_pat = re.compile(r"Data movement:\s([\d.]+\s*[A-Z]*B) ingress,\s([\d.]+\s*[A-Z]*B) .* egress")
    overlap_pat = re.compile(r"Storage: .*?(\d+\.\d+\s*[A-Z]B)\s+overlap")
    full_sample_pat = re.compile(r"k=(1\.000)")
    frac_sample_pat = re.compile(r"k=(0\.\d+)")

    ingress_bytes = 0
    egress_bytes = 0
    rep_size = None
    sample_rate = None

    if not os.path.exists(log_file):
        return ingress_bytes, egress_bytes, rep_size, sample_rate

    with open(log_file, "r") as fh:
        for line in fh:
            # Data-movement ingress/egress volumes.
            m = movement_pat.search(line)
            if m:
                ingress_bytes = parse_size(m.group(1))
                egress_bytes = parse_size(m.group(2))

            # Replication size reported as storage "overlap".
            m = overlap_pat.search(line)
            if m:
                rep_size = parse_size(m.group(1))

            # Sample rate: k=1.000 means full sample; k=0.XXX is fractional.
            if full_sample_pat.search(line):
                sample_rate = 1.0
            m = frac_sample_pat.search(line)
            if m:
                sample_rate = float(m.group(1))

    return ingress_bytes, egress_bytes, rep_size, sample_rate

def calculate_traffic_percentiles(traffic_dir: str, start_date: datetime, end_date: datetime, debug: bool = False):
    """Reads traffic data from CSV files and computes percentiles"""
    all_traffic_rates = []

    for single_date in pd.date_range(start_date, end_date):
        traffic_file = os.path.join(traffic_dir, f"traffic_{single_date.strftime('%Y%m%d')}.csv")
        if os.path.exists(traffic_file):
            df = pd.read_csv(traffic_file)
            df['egress_rate_bps'] = df['egress_rate_presto_bps'] + df['egress_rate_spark_bps']
            df['ingress_rate_bps'] = df['ingress_rate_presto_bps'] + df['ingress_rate_spark_bps']
            df['traffic_rate_bps'] = df['egress_rate_bps'] + df['ingress_rate_bps']
            if debug and len(df) != 1440:
                print(f"Check {traffic_file}: {len(df)}")
            all_traffic_rates.extend(df["traffic_rate_bps"].tolist())
        else:
            print(f"Traffic file not found: {traffic_file}")

    if not all_traffic_rates:
        return None, None, None  # No data found

    return (
        int(np.percentile(all_traffic_rates, 90)),
        int(np.percentile(all_traffic_rates, 95)),
        int(np.percentile(all_traffic_rates, 99)),
    )

def log_period_statistics(start_date: datetime, end_date: datetime, scheduling_policy: str, c: int,
                          cloud_compute_rate: float,
                          egress_byte_Presto, ingress_byte_Presto,
                          egress_byte_Spark, ingress_byte_Spark,
                          dir_path, opt_dir_path, traffic_rate_disabled=False,
                          rep_rate=None
                          ):
    """Append one summary row for a simulated period to <dir_path>/log.csv.

    Combines the per-engine ingress/egress byte counters with data-movement /
    replication / sample-rate stats parsed from the optimizer's log.txt, plus
    P90/P95/P99 traffic rates computed from per-minute traffic CSVs (skipped
    when traffic_rate_disabled). The CSV is appended under an exclusive fcntl
    lock so concurrent runs do not interleave rows; the header line is written
    only when the file does not yet exist.

    opt_dir_path=None marks a baseline run: movement bytes are 0 and the
    replication size is rep_rate * total fleet size (hard-coded "299.12PB").
    NOTE(review): in that case rep_rate must not be None or this raises —
    presumably guaranteed by the baseline caller; verify.
    """

    period_str = f"{start_date.strftime('%Y%m%d')}-{end_date.strftime('%Y%m%d')}"
    cloud_computation_rate_str = f"{cloud_compute_rate*100:.2f}%"  # Get actual ratio from scheduler
    cloud_computation_target = c

    # Extract movement bytes
    if opt_dir_path is None:
        # Baseline path: no optimizer output directory to parse.
        movement_ingress = movement_egress = 0
        rep_size = rep_rate * parse_size("299.12PB")
        sample_rate = 1.0
    else:
        movement_ingress, movement_egress, rep_size, sample_rate = extract_movement_rep_and_sample(os.path.join(opt_dir_path, "log.txt"))

    if not traffic_rate_disabled:
        # Compute traffic percentiles
        if rep_rate is not None:
            # Baselines write traffic CSVs directly under dir_path, not c<N>/.
            print("Evaluating baselines")
            traffic_dir = dir_path
        else:
            traffic_dir = os.path.join(dir_path, f"c{c}")
        p90, p95, p99 = calculate_traffic_percentiles(traffic_dir, start_date, end_date)
    else:
        p90 = p95 = p99 = None

    # Format log line
    log_entry = (f"{period_str},{scheduling_policy},{cloud_computation_rate_str},{cloud_computation_target},"
                 f"{ingress_byte_Presto},{egress_byte_Presto},"
                 f"{ingress_byte_Spark},{egress_byte_Spark},"
                 f"{p90},{p95},{p99},{movement_ingress},{movement_egress},"
                 f"{rep_size},{sample_rate}\n")

    log_file = os.path.join(dir_path, f"log.csv")
    write_header = not os.path.exists(log_file)  # Write header if file does not exist

    try:
        with open(log_file, "a") as f:
            fcntl.flock(f, fcntl.LOCK_EX)  # Lock file to prevent interference
            if write_header:
                f.write(
                    "period,mode,cloud_computation_ratio,cloud_computation_target,"
                    "ingress_byte_Presto,egress_byte_Presto,"
                    "ingress_byte_Spark,egress_byte_Spark,"
                    "P90_traffic_bps,P95_traffic_bps,P99_traffic_bps,"
                    "movement_ingress_bytes,movement_egress_bytes,rep_bytes,sample_rate\n")
            f.write(log_entry)
            fcntl.flock(f, fcntl.LOCK_UN)  # Unlock file after writing
        logging.info(f"Logged period statistics to log.csv: {log_entry.strip()}")

    except Exception as e:
        logging.error(f"Error writing to log.csv: {e}")

def read_Spark(date: datetime):
    """Load one day of Spark job traces from jobTraces/ as a normalized DataFrame.

    Adds a 'table' key ("db.table"), tags rows with JobType.SPARK, and returns
    only the columns the schedulers consume.
    """
    print(f"Reading Spark jobs for {date}")
    trace_file = f"jobTraces/{date.strftime('%Y%m%d')}-Spark.csv"
    df = pd.read_csv(trace_file, dtype={
                     'job_id': str, 'start_time': str, 'duration': float,
                    'cputime': float, 'db_name': str, 'table_name': str,
                    #'uown_names': str,
                    'inputDataSize': float,
                    'outputDataSize': float, 'template_id': str
    }, na_values=['\\N'])
    # Force string dtype so NaN entries become the literal "nan" before the
    # 'table' key is concatenated below.
    for col in ('db_name', 'table_name', 'template_id'):
        df[col] = df[col].astype(str)

    df['table'] = df['db_name'] + '.' + df['table_name']
    df['type'] = JobType.SPARK
    wanted = ['job_id', 'start_time', 'duration', 'cputime', 'table', 'inputDataSize',
              'outputDataSize', 'template_id', 'type', 'db_name', 'table_name', 'uown_names']
    return df[wanted]

def read_Presto(date: datetime):
    """Load one day of Presto job traces from jobTraces/ as a normalized DataFrame.

    Mirrors read_Spark: adds a 'table' key ("db.table"), tags rows with
    JobType.PRESTO, and returns only the columns the schedulers consume.
    """
    print(f"Reading Presto jobs for {date}")
    df = pd.read_csv(f"jobTraces/{date.strftime('%Y%m%d')}-Presto.csv", dtype={
        'job_id': str, 'start_time': str, 'duration': float,
        'cputime': float, 'db_name': str, 'table_name': str,
        #'uown_names': str,
        'inputDataSize': float,
        'outputDataSize': float, 'template_id': str
    }, na_values=['\\N'])
    # Consistency fix: read_Spark casts these columns to str before building
    # 'table'; without the cast, a missing db_name/table_name (na_values '\\N')
    # makes 'table' NaN here while the same row in a Spark trace gets "nan.nan".
    df['db_name'] = df['db_name'].astype(str)
    df['table_name'] = df['table_name'].astype(str)
    df['template_id'] = df['template_id'].astype(str)
    df['type'] = JobType.PRESTO
    df['table'] = df['db_name'] + '.' + df['table_name']
    df = df[['job_id', 'start_time', 'duration', 'cputime', 'table', 'inputDataSize',
             'outputDataSize', 'template_id', 'type', 'db_name', 'table_name', 'uown_names']]
    return df

def setup_logger(log_path):
    """Route root-logger INFO output to a single file handler at log_path.

    Clears any previously attached handlers first so repeated calls (one per
    run/period) do not duplicate log lines.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # Drop stale handlers from a previous setup_logger call.
    if root.hasHandlers():
        root.handlers.clear()

    handler = logging.FileHandler(log_path, mode='a')
    handler.setLevel(logging.INFO)
    handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    root.addHandler(handler)

def flush_oldest_minute_buckets(minute_buckets, cutoff_minute, dir_path):
    """Pop minute buckets older than cutoff_minute (all of them when None) and
    append their per-minute traffic rates to traffic_<day>.csv under dir_path.

    minute_buckets is an OrderedDict keyed by minute timestamps in insertion
    (chronological) order; flushed entries are removed from it.
    """
    def _oldest_expired():
        if not minute_buckets:
            return False
        oldest = next(iter(minute_buckets))
        return cutoff_minute is None or oldest < cutoff_minute

    while _oldest_expired():
        minute, traffic = minute_buckets.popitem(last=False)  # pop oldest entry
        out_path = os.path.join(dir_path, f"traffic_{minute.date().strftime('%Y%m%d')}.csv")
        # Per-minute bps: bytes accumulated over the minute * 8 bits / 60 s.
        row = pd.DataFrame([{
            'minute': minute.strftime('%H:%M'),
            'egress_rate_presto_bps': int((traffic['egress_byte_Presto'] * 8) / 60),
            'ingress_rate_presto_bps': int((traffic['ingress_byte_Presto'] * 8) / 60),
            'egress_rate_spark_bps': int((traffic['egress_byte_Spark'] * 8) / 60),
            'ingress_rate_spark_bps': int((traffic['ingress_byte_Spark'] * 8) / 60),
        }])
        # TODO: batch writes to reduce I/O
        row.to_csv(out_path, mode='a', index=False, header=not os.path.exists(out_path))

# Used for bug fixing, ignore it
def iterate_logs(opt_path, num_of_week, yugong):
    """Recompute traffic percentiles and movement stats for past runs and dump
    one row per cloud target per week to <opt_path>/fix.csv."""
    period_day = 7
    with open(f"{opt_path}/fix.csv", "w") as out:
        out.write("period,cloud_computation_target,P90_traffic_bps,P95_traffic_bps,P99_traffic_bps,"
                  "movement_ingress_bytes,movement_egress_bytes,rep_bytes,sample_rate\n")
        for c in [30, 50, 70]:
            period_start = datetime.strptime("2024-10-29", "%Y-%m-%d")
            for week in range(num_of_week):
                start_date = period_start + timedelta(days=week * period_day)
                end_date = start_date + timedelta(days=period_day - 1)
                p90, p95, p99 = calculate_traffic_percentiles(
                    os.path.join(opt_path, f"c{c}"), start_date, end_date, debug=True)

                # First week has no label; later weeks are tagged with the
                # previous week's start date (MMDD).
                label = "" if week == 0 else "_" + (start_date - timedelta(days=period_day)).strftime("%m%d")

                # Yugong runs use bw0.20 directories; Moirai runs use bw0.02.
                bw = "0.20" if yugong else "0.02"
                opt_dir_path = os.path.join(opt_path, f"test_run_c{c}_bw{bw}_local{100-c}{label}")

                log_file = os.path.join(opt_dir_path, "log.txt")
                if os.path.exists(log_file):
                    mv_in, mv_out, rep_size, sample_rate = extract_movement_rep_and_sample(log_file)
                    out.write(f"{start_date.strftime('%Y%m%d')}-{end_date.strftime('%Y%m%d')},"
                              f"{c},{p90},{p95},{p99},{mv_in},{mv_out},{rep_size},{sample_rate}\n")

def process_yugong(c, num_of_week, dir_path, debug):
    """Replay job traces week by week under Yugong (project-ownership-based)
    placement for cloud-compute target c (percent), accumulating per-engine
    ingress/egress bytes and per-minute traffic rates, then batch-log one
    summary row per week via log_period_statistics.

    NOTE(review): relies on a Scheduler class defined/imported elsewhere in
    this file (not visible in this chunk) — verify its place_query contract.
    """
    # create c30 or c50 or c70 directory under parent to store traffic_rate
    if os.path.exists(os.path.join(dir_path, f"c{c}")):
        shutil.rmtree(os.path.join(dir_path, f"c{c}"), ignore_errors=False)
    os.makedirs(f"{dir_path}/c{c}", exist_ok=True)
    period_day = 7

    setup_logger(os.path.join(dir_path, f'routing_c{c}.txt'))

    # prepare scheduler
    period_start = datetime.strptime("2024-10-29", "%Y-%m-%d")

    """ to calculate traffic rate per minute, """
    minute_buckets = OrderedDict()  # OrderedDict keeps minute order for easy popping

    # store logs for each period
    period_logs = []

    for period_offset in range(num_of_week):
        start_date = period_start + timedelta(days=period_offset * period_day)
        # Week 0 has no directory label; later weeks are tagged with the
        # previous week's start date (MMDD) to match the optimizer's output dirs.
        if period_offset == 0:
            label = ""
        else:
            label = "_" + (start_date - timedelta(days=period_day)).strftime("%m%d")

        # Header: start_time,job_id,template_id,duration,
        # uown_names,inputDataSize,cputime, type
        # TODO: this can be parallelized
        df_presto = pd.concat([read_Presto(start_date + timedelta(days=i)) for i in range(period_day)])
        df_spark = pd.concat([read_Spark(start_date + timedelta(days=i)) for i in range(period_day)])

        df = pd.concat([df_spark, df_presto])
        df['totalDataSize'] = df['inputDataSize'] + df['outputDataSize']
        df = df.sort_values(['start_time', 'job_id'])
        print("first 5 jobs", df.head())
        logging.info(f"Week {period_offset + 1}, starting on {start_date}")
        logging.info(f"# of jobs: {len(df['job_id'].unique())}")


        if debug:
            print(f"debug mode: retain first 3K rows", flush=True)
            # retain first 3K rows
            jobs = df.head(3000).groupby(['start_time', 'job_id'])
            print(jobs.head(1))
        else:
            # One group per job; a job may span multiple rows (one per table).
            jobs = df.groupby(['start_time', 'job_id'])

        """ prepare ownership info for query and table """
        ownership = Ownership()
        #print(f"# of unique query ownership after processing: {df['uown_names'].nunique()}", flush=True)

        table_df = pd.read_csv("report-table-size-20241021.csv",
                               dtype={'hive_database_name': str, 'hive_table_name': str, 'uown_names': str},
                               na_values=['\\N'])
        table_df['table'] = table_df['hive_database_name'] + '.' + table_df['hive_table_name']
        for table, uown_names in zip(table_df['table'], table_df['uown_names']):
            if pd.isna(uown_names):  # Check for NaN values
                continue
            # print(f"Table {table} has ownership {uown_names}", flush=True)
            ownership.add_table_ownership(table, uown_names)

        # prepare scheduler with optimization results
        # NOTE(review): Yugong runs use bw0.20 directories (vs bw0.02 for the
        # non-Yugong path); the table-size snapshot is chosen by start date.
        scheduler = Scheduler(dir_path=os.path.join(dir_path, f"test_run_c{c}_bw0.20_local{100-c}{label}"),
                              table_size_path='report-table-size-0907.csv' if start_date < datetime.strptime("2024-05-13", "%Y-%m-%d") else 'report-table-size-20241021.csv',
                              yugong=True, weight_lookup=None, ownership=ownership)

        # Per-week traffic counters, split by engine.
        egress_byte_Presto = 0
        ingress_byte_Presto = 0
        egress_byte_Spark = 0
        ingress_byte_Spark = 0

        # enumerate jobs
        for (start_time, job_id), group in jobs:
            job_type = group['type'].iloc[0]
            # Spark rows repeat the job-level cputime; Presto rows must be summed.
            if job_type == JobType.SPARK:
                cputime = group['cputime'].iloc[0]
            else:
                cputime = group['cputime'].sum()
            # NOTE(review): uown_names is passed where a template_id is expected —
            # presumably intentional for ownership-based (Yugong) placement; confirm.
            template_id = group['uown_names'].iloc[0]
            table_volume_list = [(row['table'], row['inputDataSize'], row['outputDataSize']) for _, row in group.iterrows()]

            placement_y, egress_byte, ingress_byte = scheduler.place_query(template_id, cputime, table_volume_list,
                                                         policy='size-predict',
                                                         target_cloud_cpu_ratio=c / 100,
                                                         info=start_time)

            if job_type == JobType.SPARK:
                egress_byte_Spark += egress_byte
                ingress_byte_Spark += ingress_byte
            else:
                egress_byte_Presto += egress_byte
                ingress_byte_Presto += ingress_byte

            """ traffic rate """
            duration = group['duration'].iloc[0]
            # Spark timestamps carry fractional seconds; Presto's do not.
            if job_type == JobType.SPARK:
                tStart = datetime.strptime(start_time, "%Y-%m-%d %H:%M:%S.%f")
            else:
                tStart = datetime.strptime(start_time, "%Y-%m-%d %H:%M:%S")
            tEnd = tStart + timedelta(seconds=duration)
            start_minute = tStart.replace(second=0, microsecond=0)
            # Round the end up to the next whole minute so partial minutes count.
            end_minute = (tEnd + timedelta(seconds=59)).replace(second=0, microsecond=0)

            # Flush expired minute buckets (older than job_start_minute)
            #flush_oldest_minute_buckets(minute_buckets, start_minute, os.path.join(dir_path, f"c{c}"))

            # NOTE(review): `min` shadows the builtin within this loop body.
            min = start_minute
            total_minute = (end_minute - start_minute).total_seconds() / 60
            # Spread the job's bytes evenly across every minute it spans.
            while min < end_minute:
                if min not in minute_buckets:
                    minute_buckets[min] = {'egress_byte_Presto': 0, 'ingress_byte_Presto': 0,
                                           'egress_byte_Spark': 0, 'ingress_byte_Spark': 0}
                if job_type == JobType.SPARK:
                    minute_buckets[min]['egress_byte_Spark'] += egress_byte / total_minute
                    minute_buckets[min]['ingress_byte_Spark'] += ingress_byte / total_minute
                else:
                    minute_buckets[min]['egress_byte_Presto'] += egress_byte / total_minute
                    minute_buckets[min]['ingress_byte_Presto'] += ingress_byte / total_minute
                min += timedelta(minutes=1)

        logging.info(f"Egress {human_readable_size(egress_byte_Presto + egress_byte_Spark)}: "
                     f"Presto {human_readable_size(egress_byte_Presto)}, Spark {human_readable_size(egress_byte_Spark)}")
        logging.info(f"Ingress {human_readable_size(ingress_byte_Presto + ingress_byte_Spark)}: "
                     f"Presto {human_readable_size(ingress_byte_Presto)}, Spark {human_readable_size(ingress_byte_Spark)}")

        # Log period statistics
        period_logs.append({
            "start_date": start_date,
            "end_date": start_date + timedelta(days=period_day-1),
            "scheduling_policy": "size-predict",
            "c": c,
            "cloud_compute_ratio": scheduler.get_cloud_computation_ratio(),  # Store only the ratio
            "egress_byte_Presto": egress_byte_Presto,
            "ingress_byte_Presto": ingress_byte_Presto,
            "egress_byte_Spark": egress_byte_Spark,
            "ingress_byte_Spark": ingress_byte_Spark,
            "dir_path": dir_path,
            "opt_dir_path": os.path.join(dir_path, f"test_run_c{c}_bw0.20_local{100-c}{label}")
        })

    # Flush remaining minute buckets
    flush_oldest_minute_buckets(minute_buckets, None, os.path.join(dir_path, f"c{c}"))

    # Now log all stored period statistics in a single batch
    for log_entry in period_logs:
        log_period_statistics(
            log_entry["start_date"],
            log_entry["end_date"],
            log_entry["scheduling_policy"],
            log_entry["c"],
            log_entry["cloud_compute_ratio"],  # Only store ratio instead of full scheduler object
            log_entry["egress_byte_Presto"],
            log_entry["ingress_byte_Presto"],
            log_entry["egress_byte_Spark"],
            log_entry["ingress_byte_Spark"],
            log_entry["dir_path"],
            log_entry["opt_dir_path"]
        )

if __name__ == "__main__":
    # process_baseline("rep_x_month", "baselines/rep_x_month_c30_rep0.210_20250301_170015",
    #                  1, 30, 0.21)
    # Dispatch on CLI mode: Yugong places jobs by project ownership;
    # otherwise run the policy-driven per-job scheduler.
    if args.yugong:
        process_yugong(args.c, args.num_week, args.opt_path, args.debug)
    else:
        process_jobs(args.c, args.num_week, args.opt_path, args.debug, args.policy,
                     traffic_rate_disabled=args.simple)

    # iterate_logs(args.opt_path, args.num_week, args.yugong)
    # dollar costs (will be included in eval)
    # Moirai_weekly_cost_print(args.opt_path)