import argparse
import os
from datetime import datetime
from enum import Enum
from random import random

import numpy as np
import pandas as pd

from utility import human_readable_size

"""
Eventual goal is to support three baselines:
1. No replication, or replicate the most recent X months of data; parameter: X
* Tables are randomly split into two sites, but consider on-prem to have storage space reserved for X months of data
* For example, if replicated data is 21% of total data and split is 50% on-prem, 50% cloud, then 
** on-prem has 29% of unique data, cloud has 50% of unique data, and 21% of data are replicated
** and when generating placement, we allocate 50%/121% of data to on-prem and 71%/121% of data to cloud

2. rep top access density tables, parameter: replication budget
* Rank the tables by access density (read+write access size/table size), and replicate the top tables until the replication budget is exhausted
* Other tables are randomly split into two sites
* For example, if replicated data is 5% of total data and split is 50% on-prem, 50% cloud, then
** on-prem has 45% of unique data, cloud has 50% of unique data, and 5% of data are replicated
** and when generating placement, we allocate 45%/95% of data to on-prem and 50%/95% of data to cloud

3. optimized data placement based on Moirai job distribution
* Data should be placed to the site where most jobs depending on them are running
"""

# Define command-line arguments.
# NOTE: the CLI path is currently unused (see the commented-out call at the
# bottom of the file); the hard-coded sweeps in __main__ drive the runs.
parser = argparse.ArgumentParser(description="Data placement for baselines")
# Fix: previously only ["rep_rtd", "volley", "rep_x_month"] were accepted,
# so the "MoiJob" and "volley_new" baselines implemented by `baselines`
# could never be selected from the command line. "volley" is kept for
# backward compatibility even though the class rejects it with ValueError.
parser.add_argument("--baseline", type=str, help="baseline to run",
                    choices=["rep_rtd", "volley", "rep_x_month", "MoiJob", "volley_new"])
# parser.add_argument("--rep_strategy", type=str, help="scheduling policy",
#                     choices=["rep_x_month", "rep_rtd"])
# parser.add_argument("--placement_strategy", type=str, help="placement strategy",
#                     choices=["random", "volley"])
parser.add_argument("--rep_rate", type=float, help="Pre-selecting replication budget rate, [0, 1]")
parser.add_argument("--c", type=int, default=30, help="Portion of compute to cloud")


args = parser.parse_args()

class Status(Enum):
    """Placement decision for a table.

    The numeric values are load-bearing: persist_placement() maps each
    member to a (z, w) indicator pair in the output CSV.
    """
    ONPREM = 0  # stored only on-premises
    CLOUD = 1   # stored only in the cloud
    REP = 2     # replicated to both sites

class baselines:
    """Build and persist a table-placement plan for one baseline strategy.

    Each instance runs a single experiment end-to-end: load per-table sizes
    and the workload trace, decide per-table placement (on-prem / cloud /
    replicated on both sites), and write the decisions to a timestamped CSV
    under ``baselines/``.

    NOTE(review): lowercase class name predates PEP 8 naming here; it is
    kept as-is to avoid breaking callers.
    """
    def __init__(self, tag: str, cloud_target: int, rep_rate: float):
        """Run the baseline selected by `tag` end-to-end and persist it.

        Args:
            tag: one of "rep_x_month", "rep_rtd", "MoiJob", "volley_new";
                any other value raises ValueError.
            cloud_target: percentage of compute/capacity assigned to the
                cloud, in [0, 100].
            rep_rate: replication budget as a fraction of total data size,
                in [0, 1].
        """
        self.parent_dir = "baselines"
        os.makedirs(self.parent_dir, exist_ok=True)

        self.tag = tag
        # One output directory per run, disambiguated by a timestamp.
        my_dir_name = f"{tag}_c{cloud_target}_rep{rep_rate:.3f}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.my_dir = os.path.join(self.parent_dir, my_dir_name)
        os.makedirs(self.my_dir, exist_ok=True)

        self.cloud_target = cloud_target
        assert 0 <= cloud_target <= 100, "cloud_target must be in [0, 100]"

        self.rep_rate = rep_rate

        # db.table -> size (presumably bytes: persist_placement divides by
        # 1024**3 to emit GiB — TODO confirm units of dir_size).
        self.table_size_lookup = self._load_table_size()
        self.total_data_size = sum(self.table_size_lookup.values())
        self.on_prem_data_size = 0
        # On-prem storage budget is the complement of the cloud share.
        self.on_prem_capacity = self.total_data_size * (100 - cloud_target) / 100

        # header: abstractFingerPrint,db_name,table_name,inputDataSize,outputDataSize
        self.workload = self._load_workload()

        self.placement = {} # db_table name -> Status
        if tag == "rep_x_month":
            assert rep_rate in [0.21] # 3M
            # Reserve on-prem space up front for the replicated recent data.
            self.on_prem_data_size = self.total_data_size * rep_rate
            assert self.on_prem_data_size <= self.on_prem_capacity, "Not enough capacity for replication"
            self.rep_x_month_placement()
        elif tag == "rep_rtd":
            self.preselect_replication()
            self.data_placement_random()
        elif tag == "MoiJob":
            self.preselect_replication()
            df = self.get_moirai_job_distribution()
            self.data_placement_by_compute_distribution(df)
            #self.volley_placement()
        elif tag == "volley_new":
            self.preselect_replication()
            df = self.get_random_project_distribution()
            self.data_placement_by_compute_distribution(df)
        else:
            raise ValueError(f"Unknown baseline: {tag}")

        self.persist_placement()

    def _load_table_size(self):
        """Load per-table directory sizes into a {"db.table": size} dict.

        Rows with non-positive sizes are dropped, and the rows are shuffled
        so later (insertion-ordered) iteration does not favor any table.
        """
        df = pd.read_csv("report-table-size-20241021.csv")

        required_columns = {"hive_database_name", "hive_table_name", "dir_size"}
        assert required_columns.issubset(df.columns), f"Missing columns: {required_columns - set(df.columns)}"
        df['hive_database_name'] = df['hive_database_name'].astype(str)
        df['hive_table_name'] = df['hive_table_name'].astype(str)
        df['dir_size'] = df['dir_size'].astype(int)
        df = df[df['dir_size'] > 0]

        table_size_lookup = {}
        # random shuffle to avoid bias
        df = df.sample(frac=1)
        for row in df.itertuples():
            lookup_key = row.hive_database_name + "." + row.hive_table_name
            value = row.dir_size
            table_size_lookup[lookup_key] = value

        print("total data size", human_readable_size(df['dir_size'].sum()))
        return table_size_lookup

    def _load_workload(self):
        """Concatenate the Presto and Spark access traces into one DataFrame."""
        df_presto = pd.read_csv("newTraces/report-abFP-volume-table-20241022-20241028-Presto.csv")
        df_spark = pd.read_csv("newTraces/report-abFP-volume-table-20241022-20241028-Spark.csv")

        job_data_access_df = pd.concat([df_presto, df_spark], ignore_index=True)
        # cputime is unused by the workload-driven paths; drop it early.
        job_data_access_df.drop(columns=["cputime"], inplace=True)
        return job_data_access_df

    def preselect_replication(self):
        """Greedily mark the highest access-density tables as replicated.

        Ranks tables by access density (read+write bytes / table size) and
        marks them Status.REP until the budget (rep_rate * total_data_size)
        is spent. Replicas consume on-prem capacity.
        """
        df = self.workload.copy()
        df['access_size'] = df['inputDataSize'] + df['outputDataSize']
        df['db_table'] = df['db_name'] + "." + df['table_name']
        df = df.groupby('db_table', as_index=False)['access_size'].sum()

        # Only consider tables whose size is known.
        df = df[df['db_table'].isin(self.table_size_lookup)]

        df['access_density'] = df['access_size'] / df['db_table'].map(self.table_size_lookup)
        df = df.sort_values(by='access_density', ascending=False)

        # A table that would overshoot the budget is skipped (not break),
        # so a smaller table further down the ranking may still fit.
        total_rep_size = 0
        for _, row in df.iterrows():
            table = row['db_table']
            table_size = self.table_size_lookup.get(table, 0)
            if total_rep_size >= self.total_data_size * self.rep_rate:
                break
            elif total_rep_size + table_size <= self.total_data_size * self.rep_rate + 1024**3: # 1GB buffer
                self.placement[table] = Status.REP
                total_rep_size += table_size
        print(f"replicated data size: {human_readable_size(total_rep_size)}")
        self.on_prem_data_size += total_rep_size

    def rep_x_month_placement(self):
        """Randomly split tables across sites, discounting replicated data.

        Each table's effective (unique) size is scaled by (1 - rep_rate),
        modeling that a rep_rate fraction of the data is the recent window
        already replicated to both sites (module docstring, baseline 1).
        Site choice is random, weighted by remaining capacity.
        """
        on_prem_capacity_left = self.on_prem_capacity - self.on_prem_data_size
        cloud_capacity_left = self.total_data_size - self.on_prem_capacity - self.on_prem_data_size
        print(f"on-prem capacity left: {human_readable_size(on_prem_capacity_left)}, "
                f"cloud capacity left: {human_readable_size(cloud_capacity_left)}")
        prob_on_prem = on_prem_capacity_left / (on_prem_capacity_left + cloud_capacity_left)
        # Live dict view: covers tables placed before this loop runs.
        placed_tables = self.placement.keys()
        for table, size in self.table_size_lookup.items():
            effective_size = size * (1-self.rep_rate)
            if table in placed_tables:
                continue
            if on_prem_capacity_left > effective_size and cloud_capacity_left > effective_size:
                # Both sites fit the table: pick randomly, biased by capacity.
                if random() < prob_on_prem:
                    self.placement[table] = Status.ONPREM
                    on_prem_capacity_left -= effective_size
                else:
                    self.placement[table] = Status.CLOUD
                    cloud_capacity_left -= effective_size
            elif on_prem_capacity_left > effective_size:
                self.placement[table] = Status.ONPREM
                on_prem_capacity_left -= effective_size
            elif cloud_capacity_left > effective_size:
                self.placement[table] = Status.CLOUD
                cloud_capacity_left -= effective_size
            else:
                # Neither site has room: overflow to the cloud (may drive
                # cloud_capacity_left negative).
                # print(f"Table {table} ({human_readable_size(size)}) is larger than remaining capacity "
                #       f"on-prem ({human_readable_size(on_prem_capacity_left)}) "
                #       f"and cloud ({human_readable_size(cloud_capacity_left)})")
                self.placement[table] = Status.CLOUD
                cloud_capacity_left -= effective_size
        self.on_prem_data_size = self.on_prem_capacity - on_prem_capacity_left
        print(f"on-prem data size: {human_readable_size(self.on_prem_data_size)}")

    def data_placement_random(self):
        """Randomly place all not-yet-placed tables, weighted by capacity.

        NOTE(review): unlike rep_x_month_placement, the cloud side here is
        NOT charged for already-replicated data (on_prem_data_size) —
        confirm this asymmetry is intentional (cloud treated as elastic
        for replicas, per the baseline-2 accounting in the module docstring).
        """
        on_prem_capacity_left = self.on_prem_capacity - self.on_prem_data_size
        cloud_capacity_left = self.total_data_size - self.on_prem_capacity
        print(f"on-prem capacity left: {human_readable_size(on_prem_capacity_left)}, "
              f"cloud capacity left: {human_readable_size(cloud_capacity_left)}")
        prob_on_prem = on_prem_capacity_left / (on_prem_capacity_left + cloud_capacity_left)
        # Live dict view: covers tables replicated by preselect_replication.
        placed_tables = self.placement.keys()
        for table, size in self.table_size_lookup.items():
            if table in placed_tables:
                continue
            if on_prem_capacity_left > size and cloud_capacity_left > size:
                # Both sites fit the table: pick randomly, biased by capacity.
                if random() < prob_on_prem:
                    self.placement[table] = Status.ONPREM
                    on_prem_capacity_left -= size
                else:
                    self.placement[table] = Status.CLOUD
                    cloud_capacity_left -= size
            elif on_prem_capacity_left > size:
                self.placement[table] = Status.ONPREM
                on_prem_capacity_left -= size
            elif cloud_capacity_left > size:
                self.placement[table] = Status.CLOUD
                cloud_capacity_left -= size
            else:
                # Neither site has room: overflow to the cloud.
                # print(f"Table {table} ({human_readable_size(size)}) is larger than remaining capacity "
                #       f"on-prem ({human_readable_size(on_prem_capacity_left)}) "
                #       f"and cloud ({human_readable_size(cloud_capacity_left)})")
                self.placement[table] = Status.CLOUD
                cloud_capacity_left -= size
        self.on_prem_data_size = self.on_prem_capacity - on_prem_capacity_left
        print(f"on-prem data size: {human_readable_size(self.on_prem_data_size)}")

    def get_moirai_job_distribution(self):
        """Attach Moirai's per-job site decision to the workload rows.

        Reads baselines/query_placement_c{cloud_target}.csv (columns: abFP, y)
        and maps each workload row's fingerprint to its assigned site; rows
        with no Moirai decision are dropped.

        Returns:
            Workload DataFrame with an added 'Status' column (interpreted
            downstream as 0 = on-prem, anything else = cloud).

        Raises:
            FileNotFoundError: if the Moirai placement file is missing.
        """
        Moirai_path = f"baselines/query_placement_c{self.cloud_target}.csv"
        if not os.path.exists(Moirai_path):
            raise FileNotFoundError(f"Path {Moirai_path} not found")
        # header: abFP,y
        job_dist = pd.read_csv(Moirai_path)
        placement_map = dict(zip(job_dist['abFP'], job_dist['y']))
        df = self.workload.copy()
        df['Status'] = df['abstractFingerPrint'].map(placement_map)
        df = df[df['Status'].notnull()]

        return df

    def get_random_project_distribution(self):
        """Randomly assign whole jobs (fingerprints) to sites by CPU quota.

        Shuffles the unique fingerprints, then greedily sends each to the
        cloud while its total cputime still fits the cloud quota
        (cloud_target% of total cputime); otherwise it stays on-prem.

        NOTE(review): the per-fingerprint cputime sum inside the loop is
        O(#rows * #fingerprints); a single groupby('abstractFingerPrint')
        sum would compute all of them in one pass. Also onprem_quota is
        decremented but never checked.

        Returns:
            DataFrame with an added 'Status' column (0 = on-prem, 1 = cloud).
        """
        assert os.path.exists("yugongTraces/report-uown-volume-table-20241022-20241028.csv")
        # header: abstractFingerPrint, db_name, table_name, inputDataSize, outputDataSize, cputime
        df = pd.read_csv("yugongTraces/report-uown-volume-table-20241022-20241028.csv")
        prob_in_cloud = self.cloud_target / 100

        # Get unique abstractFingerPrint values
        unique_abFPs = df['abstractFingerPrint'].unique()
        total_cputime = df['cputime'].sum()
        cloud_quota = total_cputime * prob_in_cloud
        onprem_quota = total_cputime * (1 - prob_in_cloud)

        shuffled_abFPs = np.random.permutation(unique_abFPs)

        abFP_status_map = {}
        for abFP in shuffled_abFPs:
            cputime = df[df['abstractFingerPrint'] == abFP]['cputime'].sum()
            if cputime <= cloud_quota:
                abFP_status_map[abFP] = 1
                cloud_quota -= cputime
            else:
                abFP_status_map[abFP] = 0
                onprem_quota -= cputime

        # # Assign Status randomly with probability prob_in_cloud for being 1 (in cloud)
        # abFP_status_map = {abFP: np.random.choice([0, 1], p=[1 - prob_in_cloud, prob_in_cloud]) for abFP in
        #                    unique_abFPs}

        # Map the assigned Status to the dataframe
        df['Status'] = df['abstractFingerPrint'].map(abFP_status_map)

        return df

    def data_placement_by_compute_distribution(self, df):
        """Place each table on the site that generates more traffic to it.

        Args:
            df: workload rows carrying a 'Status' column
                (0 = on-prem, anything else = cloud).

        Tables whose on-prem traffic dominates are placed on-prem while
        capacity remains; everything else goes to the cloud. Tables absent
        from `df` (never accessed) are then back-filled on-prem first,
        overflowing to the cloud.
        """
        df['access_size'] = df['inputDataSize'] + df['outputDataSize']
        grouped = df.groupby(['db_name', 'table_name', 'Status'], as_index=False)['access_size'].sum()
        grouped_map = {}

        # db.table -> {Status.ONPREM: traffic, Status.CLOUD: traffic}
        for _, row in grouped.iterrows():
            key = f"{row['db_name']}.{row['table_name']}"
            grouped_map.setdefault(key, {Status.ONPREM: 0, Status.CLOUD: 0})

            if row['Status'] == 0:
                grouped_map[key][Status.ONPREM] = row['access_size']
            else:
                grouped_map[key][Status.CLOUD] = row['access_size']

        onprem_size = 0
        cloud_size = 0
        for table_key, traffic in grouped_map.items():
            onprem_traffic = traffic[Status.ONPREM]
            cloud_traffic = traffic[Status.CLOUD]
            table_size = self.table_size_lookup.get(table_key, 0)
            # Skip already-replicated tables and tables of unknown size.
            if table_key in self.placement or table_size == 0:
                continue

            if onprem_traffic > cloud_traffic and self.on_prem_data_size + onprem_size + table_size <= self.on_prem_capacity:
                self.placement[table_key] = Status.ONPREM
                onprem_size += table_size
            else:
                self.placement[table_key] = Status.CLOUD
                cloud_size += table_size

        print(f"on-prem new data size: {human_readable_size(onprem_size)}, "
              f"cloud new data size: {human_readable_size(cloud_size)}")
        self.on_prem_data_size += onprem_size
        print(f"on-prem data size: {human_readable_size(self.on_prem_data_size)}")

        # Back-fill tables that appear in the size report but not in the
        # workload: on-prem while capacity lasts, cloud otherwise.
        decisions = self.placement.keys()
        for table, table_size in self.table_size_lookup.items():
            if table not in decisions:
                if table_size + self.on_prem_data_size <= self.on_prem_capacity:
                    self.placement[table] = Status.ONPREM
                    self.on_prem_data_size += table_size
                else:
                    self.placement[table] = Status.CLOUD
                    cloud_size += table_size
        print(f"on-prem data size: {human_readable_size(self.on_prem_data_size)}")
        print(f"cloud data size: {human_readable_size(cloud_size)}")
        print(f"total data size: {human_readable_size(self.on_prem_data_size + cloud_size)}")
        print(f"=====================")

        # print(f"# of tables: {len(self.placement)}")
        # total_data_size = 0
        # onprem_size = 0
        # rep_size = 0
        # cloud_size = 0
        # for table, status in self.placement.items():
        #     total_data_size += self.table_size_lookup.get(table, 0)
        #     if status == Status.ONPREM:
        #         onprem_size += self.table_size_lookup.get(table, 0)
        #     elif status == Status.REP:
        #         rep_size += self.table_size_lookup.get(table, 0)
        #     elif status == Status.CLOUD:
        #         cloud_size += self.table_size_lookup.get(table, 0)
        # print("Sanity Check:")
        # print(f"on-prem data size: {human_readable_size(onprem_size)}")
        # print(f"replicated data size: {human_readable_size(rep_size)}")
        # print(f"cloud data size: {human_readable_size(cloud_size)}")
        # print(f"total data size: {human_readable_size(total_data_size)}")

    def persist_placement(self):
        """Write placement decisions to dataset_placement.csv in the run dir.

        Columns: table, z, w, size (GiB). Encoding: on-prem -> (z=0, w=1),
        cloud -> (z=1, w=0), replicated -> (z=0, w=0). Tables with no entry
        in table_size_lookup are skipped.
        """
        with open(os.path.join(self.my_dir, "dataset_placement.csv"), "w") as f:
            f.write("table,z,w,size\n")
            for table, status in self.placement.items():
                if status == Status.ONPREM:
                    z, w = 0, 1
                elif status == Status.CLOUD:
                    z, w = 1, 0
                else:
                    z, w = 0, 0
                size = self.table_size_lookup.get(table, None)
                if size is None:
                    #print(f"Table {table} not found in table_size_lookup")
                    continue
                f.write(f"{table},{z},{w},{size/1024**3}\n")


if __name__ == "__main__":
    cloud_targets = (30, 50, 70)

    # Baseline 1: replicate the most recent X months of data (3M ~= 21%).
    for cloud_pct in cloud_targets:
        baselines("rep_x_month", cloud_pct, 0.21)

    # Baseline 2: replicate the top access-density tables under a budget.
    for cloud_pct in cloud_targets:
        for budget in (0, 0.025):  # 0.002, 0.01, 0.025, 0.05
            baselines("rep_rtd", cloud_pct, budget)

    # Baseline 3: place data where Moirai schedules the dependent jobs.
    for cloud_pct in cloud_targets:
        baselines("MoiJob", cloud_pct, 0.002)

    # Randomized project-level placement; repeat six times to sample the
    # variance introduced by the random fingerprint assignment.
    for _trial in range(6):
        for cloud_pct in cloud_targets:
            for budget in (0,):  # , 0.002, 0.01, 0.025, 0.05
                baselines("volley_new", cloud_pct, budget)
    # baseline = baselines(args.baseline, args.c, args.rep_rate)