import re

import pandas as pd
import random



def to_seconds(interval_str: str):
    """Convert a "D HH:MM:SS[.fff]" interval string into whole seconds.

    The day count and the clock part must be separated by exactly one space.
    Any fractional-seconds suffix after '.' is discarded (truncated, not rounded).
    """
    days_token, clock_token = interval_str.split(' ')
    total = int(days_token) * 86400  # 24 * 60 * 60 seconds per day

    # Keep only HH:MM:SS; the millisecond part is ignored on purpose.
    hh, mm, ss = map(int, clock_token.partition('.')[0].split(':'))
    return total + hh * 3600 + mm * 60 + ss


def to_interval(seconds):
    """Format a duration in seconds as a "DD HH:MM:SS.000" interval string.

    Fractional seconds are truncated and the millisecond field is always ".000".
    Negative durations borrow from the day field (floor division), so -1 becomes
    "-1 23:59:59.000", which `to_seconds` maps back to -1.
    """
    whole_days, remainder = divmod(seconds, 86400)
    hrs, remainder = divmod(int(remainder), 3600)
    mins, secs = divmod(remainder, 60)
    return f"{int(whole_days):02d} {hrs:02d}:{mins:02d}:{secs:02d}.000"


def get_binary_outcome(p):
    """
    Draw a single Bernoulli sample using the module-level `random` generator.

    Parameters:
    - p (float): Chance of drawing 1; must lie in [0, 1].

    Returns:
    - int: 1 with probability p, otherwise 0.

    Raises:
    - ValueError: if p is outside [0, 1] (NaN is rejected as well).
    """
    in_range = 0 <= p <= 1
    if in_range:
        return int(random.random() < p)
    raise ValueError("p must be between 0 and 1")


def parse_size(size_str):
    """Parse a human-readable size string into a number of bytes.

    Accepts strings such as "1.5 GB", "512KB", "3 MiB" or a bare number "1024".
    A bare number is interpreted as bytes. Units are binary multiples
    (1 KB = 1024 B), the same convention `human_readable_size` uses. They are
    matched case-insensitively, and IEC spellings (KiB, MiB, ...) are accepted
    as aliases.

    Returns:
        float: the size in bytes.

    Raises:
        ValueError: if the string does not start with a number.
        KeyError: if the unit is not one of B, KB, MB, GB, TB, PB (or IEC alias).
    """
    units = {"B": 1, "KB": 1024, "MB": 1024 ** 2, "GB": 1024 ** 3, "TB": 1024 ** 4, "PB": 1024 ** 5}
    # The unit is letters only, and may be empty. With the old `\w+`, a bare
    # number like "1024" was split into "102" + unit "4".
    match = re.match(r"([\d.]+)\s*([A-Za-z]*)", size_str.strip())
    if match is None:
        raise ValueError(f"cannot parse size: {size_str!r}")
    number, unit = match.groups()
    unit = unit.upper() or "B"
    # IEC names (KiB, MiB, ...) denote the same binary multiples used here.
    if len(unit) == 3 and unit.endswith("IB"):
        unit = unit[0] + "B"
    if unit not in units:
        raise KeyError(unit)
    return float(number) * units[unit]


def egress_cost_calculator(egress_bytes, duration_in_months, price_per_gib=0.02):
    """Return the flat-rate egress cost (USD) for transferring `egress_bytes`.

    Parameters:
    - egress_bytes: number of bytes transferred.
    - duration_in_months: unused by this flat-rate model (the cost does not
      depend on the period); kept in the signature for existing callers.
    - price_per_gib: USD charged per GiB (1024 ** 3 bytes). Defaults to $0.02.

    Returns:
    - float: egress_bytes / GiB * price_per_gib.
    """
    return egress_bytes * price_per_gib / 1024 ** 3


""" egress to Internet
def egress_cost_calculator(egress_bytes, time_in_months, tier="Premium"):
    # Pricing tiers in TiB and corresponding cost per GiB in USD
    pricing_tiers = [
        (1, 0.12),  # 0-1 TiB
        (10, 0.11),  # 1-10 TiB
        (float('inf'), 0.08)  # 10+ TiB
    ]

    # Convert bytes to TiB
    egress_tib = egress_bytes / (1024 ** 4)

    # Adjust for the time period
    egress_tib_per_month = egress_tib / time_in_months

    total_cost = 0
    remaining_tib = egress_tib_per_month

    for limit, cost_per_gib in pricing_tiers:
        if remaining_tib <= 0:
            break

        if remaining_tib > limit:
            cost_tib = limit
        else:
            cost_tib = remaining_tib

        total_cost += cost_tib * 1024 * cost_per_gib
        remaining_tib -= cost_tib

    return total_cost * time_in_months
"""


# # Example usage:
# egress_bytes = 5 * 1024 ** 4  # 5 TiB
# time_in_months = 0.25  # 1/4 month
#
# cost = egress_cost_calculator(egress_bytes, time_in_months)
# print(f"The total egress cost is: ${cost:.2f}")


def human_readable_size(size_in_bytes, decimal_places=2):
    """Format a byte count using binary units (1 KB = 1024 B) up to PB.

    Values below 1024 are printed as-is with a "B" suffix (no rounding). Larger
    values are scaled to the largest unit that keeps them below 1024, with
    `decimal_places` digits. PB is the ceiling. Negative values get a leading
    "-" and use the same precision.
    """
    if size_in_bytes < 0:
        # Pass decimal_places through; it used to be dropped, so negatives
        # always used the default precision.
        return "-" + human_readable_size(-size_in_bytes, decimal_places)
    if size_in_bytes < 1024:
        return f"{size_in_bytes}B"
    for power, unit in enumerate(("KB", "MB", "GB", "TB"), start=1):
        if size_in_bytes < 1024 ** (power + 1):
            return f"{size_in_bytes / 1024 ** power:.{decimal_places}f}{unit}"
    return f"{size_in_bytes / 1024 ** 5:.{decimal_places}f}PB"


# Not currently used.
class Cache:
    """Fixed-capacity FIFO read-through cache in front of a key-value mapping.

    `get` serves keys from an in-memory dict. On a miss it looks the key up in
    `kv_store`. If found, the value is inserted, evicting the oldest-inserted
    entry when the cache is full. Hit/miss counters track effectiveness.
    """

    def __init__(self, maxsize: int, kv_store):
        self.kv_store = kv_store  # backing mapping (e.g. a dict); needs `in` and `[]`
        self.maxsize = maxsize  # max cached entries; <= 0 disables caching
        self.cache = {}
        self.queue = []  # cached keys in insertion order; queue[0] is evicted next
        self.hits = 0
        self.misses = 0

    def hit_rate(self):
        """Return hits / (hits + misses), or 0 when nothing has been looked up."""
        lookups = self.hits + self.misses
        if lookups == 0:
            return 0
        return self.hits / lookups

    def get(self, key, default=None):
        """Return the value for `key`, loading it from `kv_store` on a miss.

        Returns `default` if the key is in neither the cache nor the store.
        """
        if key in self.cache:
            self.hits += 1
            return self.cache[key]
        self.misses += 1
        if key in self.kv_store:
            value = self.kv_store[key]
            self.put(key, value)
            return value
        return default

    def put(self, key, value):
        """Insert or update `key`, evicting oldest entries if a new key needs room."""
        if key in self.cache:
            # An update neither grows the cache nor changes its FIFO position.
            # Re-queuing the key would leave a stale duplicate in self.queue.
            # Eviction would later pop that duplicate and fail on a key that is
            # already gone.
            self.cache[key] = value
            return
        if self.maxsize <= 0:
            # Caching disabled: there is nothing to evict and no room to insert.
            return
        while len(self.cache) >= self.maxsize:
            evict_key = self.queue.pop(0)
            del self.cache[evict_key]
        self.cache[key] = value
        self.queue.append(key)

"""
Deprecated
since too coarse-grained
"""


class InterNewFPWorkload:
    """Regroup per-fingerprint workload reports by the set of tables they read.

    Each abstract fingerprint (abFP) is mapped to a "table group": the sorted,
    de-duplicated set of its `db.table` names. Fingerprints with the same table
    group share a new fingerprint id, and per-table metrics are summed within the
    group. Aggregates from several reports can be kept (`persist=True`) and
    compared with `compare_table_groups`.

    Deprecated (too coarse-grained); construction is deliberately blocked.
    """

    def __init__(self):
        self.df_list = []  # persisted aggregates, one per input report
        # Deliberate guard while the class is unfinished/deprecated.
        # NOTE: `assert` is stripped under `python -O`.
        assert False, "TODO"

    def abFP_to_table_group(self, input_path, output_path=None, persist=False):
        """Aggregate one abFP report by table group.

        Parameters:
        - input_path: CSV path (or buffer) with columns abstractFingerPrint,
          db_name, table_name, inputDataSize, cputime, unique_queryid_count.
        - output_path: if given, write the aggregate here without the
          group_key column.
        - persist: if True, append the aggregate (including group_key) to
          self.df_list for `compare_table_groups`.
        """
        input_df = pd.read_csv(input_path)

        input_df['db_table'] = input_df['db_name'] + '.' + input_df['table_name']

        # A fingerprint's table group is its sorted, de-duplicated db.table names.
        input_df['group_key'] = input_df.groupby('abstractFingerPrint')['db_table'].transform(
            lambda tables: '.'.join(sorted(set(tables))))

        # Number distinct table groups newFP1, newFP2, ... in order of first appearance.
        mapping = {key: f"newFP{i}" for i, key in enumerate(input_df['group_key'].unique(), start=1)}
        input_df['new_abstractFingerPrint'] = input_df['group_key'].map(mapping)

        agg_df = input_df.groupby(['new_abstractFingerPrint', 'db_name', 'table_name']).agg(
            inputDataSize=pd.NamedAgg(column='inputDataSize', aggfunc='sum'),
            cputime=pd.NamedAgg(column='cputime', aggfunc='sum'),
            count=pd.NamedAgg(column='unique_queryid_count', aggfunc='sum'),
            group_key=pd.NamedAgg(column='group_key', aggfunc='first')
        ).reset_index()

        if persist:
            self.df_list.append(agg_df)
        if output_path:
            # Drop group_key from a copy. Dropping in place would also strip it
            # from the frame just persisted, and compare_table_groups needs it.
            agg_df.drop(columns=['group_key'], errors='ignore').to_csv(output_path, index=False)

    def compare_table_groups(self):
        """Print how table groups evolve across the persisted reports.

        For every report after the first, this prints:
        - its number of groups;
        - groups absent from the previous report;
        - groups common to all reports so far;
        - groups never seen before, and their share of the report's CPU time.
        """
        if not self.df_list:
            print("No data to compare")
            return

        group_key_sets = [set(df['group_key'].unique()) for df in self.df_list]
        prev_set = group_key_sets[0]
        total_set = prev_set
        common_set = prev_set
        print(f"# table groups in df1 {len(prev_set)}")
        for idx, (df, group_key_set) in enumerate(zip(self.df_list[1:], group_key_sets[1:]), start=2):
            print(f"# table groups in df{idx} {len(group_key_set)}")
            new_set = group_key_set - prev_set
            print(f"# table groups in df{idx} not in df{idx - 1} {len(new_set)}")
            common_set = common_set & group_key_set
            print(f"# table groups in df1 to df{idx} {len(common_set)} in common")
            never_seen_set = group_key_set - total_set
            print(f"# table groups never seen before {len(never_seen_set)}")

            # How much of this report's CPU time goes to never-seen table groups.
            cputime = df['cputime'].sum()
            cputime_new = df.loc[df['group_key'].isin(never_seen_set), 'cputime'].sum()
            ratio = cputime_new / cputime * 100 if cputime else float('nan')
            print(f"Total cputime in df{idx} {cputime}, cputime of new table groups {cputime_new}, "
                  f"ratio {ratio:.2f}%")

            total_set = total_set | group_key_set
            # Advance the baseline. It used to stay at df1, so "not in df{i}"
            # was really "not in df1".
            prev_set = group_key_set

        print(f"Total # table groups: {len(total_set)}")


# test = InterNewFPWorkload()
# test.abFP_to_table_group('report-abFP-volume-table-0818-0824.csv',
#                          # 'report-newFP-volume-table-0818-0824.csv',
#                          persist=True)
# test.abFP_to_table_group('report-abFP-volume-table-0825-0831.csv',
#                          # 'report-newFP-volume-table-0825-0831.csv',
#                          persist=True)
# test.abFP_to_table_group('report-abFP-volume-table-0901-0907.csv',
#                          # 'report-newFP-volume-table-0901-0907.csv',
#                          persist=True)
# test.abFP_to_table_group('report-abFP-volume-table-0908-0914.csv',
#                          # 'report-newFP-volume-table-0908-0914.csv',
#                          persist=True)
# test.abFP_to_table_group('report-abFP-volume-table-0915-0921.csv',
#                          # 'report-newFP-volume-table-0915-0921.csv',
#                          persist=True)
# test.compare_table_groups()

"""
Mostly to compare workloads in different time periods
other settings are usually the same
"""


class Whale:
    """Compare "whale" tables (big and unpopular) with all tables under one workload.

    Loads three CSV reports: the whale list, per-table sizes and a per-table
    workload. For all tables, whale tables and the remaining tables it prints:
    - count and on-disk size;
    - bytes read and CPU time;
    - cold tables, i.e. tables with no workload row.

    Sizes are kept in GiB internally and CPU time in seconds.
    """

    def __init__(self, whale_source_path, all_table_source_path, workload_source_path):
        # header: code,category,db_table,size_in_gbs
        self.whale_df = pd.read_csv(whale_source_path)

        # header: hive_database_name,hive_table_name,dir_size (Bytes)
        self.all_table_df = pd.read_csv(all_table_source_path)
        # Bytes -> GiB, the same unit as whale_df['size_in_gbs'].
        self.all_table_df['dir_size'] = self.all_table_df['dir_size'] / 1024 ** 3
        # Filter out empty tables (size == 0).
        self.all_table_df = self.all_table_df[self.all_table_df['dir_size'] > 0]

        # header: db_name,table_name,inputDataSize (Bytes), cputime (Interval)
        self.workload_df = pd.read_csv(workload_source_path)
        self.workload_df['inputDataSize'] = self.workload_df['inputDataSize'] / 1024 ** 3  # -> GiB
        self.workload_df['cputime'] = self.workload_df['cputime'].apply(to_seconds)  # -> seconds

        # Workload-wide totals, printed as a sanity check.
        self.total_inputDataSize = self.workload_df['inputDataSize'].sum()
        self.total_cputime = self.workload_df['cputime'].sum()
        print("Total inputDataSize: ", self.total_inputDataSize)
        print("Total cputime: ", self.total_cputime)

    def _workload_info(self, slice_df, col_db, col_table, col_size):
        """Summarise how the workload touches the tables listed in `slice_df`.

        `col_db` / `col_table` name the database and table columns of
        `slice_df`. Pass the same name for both when that column holds combined
        "db.table" strings; they are split on the first '.'. `col_size` is the
        table-size column (GiB).

        Returns:
            (read_size, cold_table_count, cold_table_size, total_cputime):
            - read_size: GiB read from the slice's tables;
            - cold_table_count, cold_table_size: number and size (GiB) of slice
              tables with no workload row;
            - total_cputime: summed workload CPU seconds.
        """
        keys_df = slice_df.copy()
        if col_db == col_table:
            # split(n=1) + .str[i] tolerates dots inside table names and empty
            # slices. The old `expand=True` form then may not produce exactly
            # two columns for the two-column assignment.
            parts = keys_df[col_db].str.split('.', n=1)
            keys_df['split_db_name'] = parts.str[0]
            keys_df['split_table_name'] = parts.str[1]
            col_db, col_table = 'split_db_name', 'split_table_name'

        join_keys = {'left_on': [col_db, col_table], 'right_on': ['db_name', 'table_name']}

        # Left join: slice tables with no matching workload row are cold.
        left_df = keys_df.merge(self.workload_df, how='left', **join_keys)
        cold_tables = left_df[left_df['inputDataSize'].isna()]

        # Right join, then keep only workload rows that matched a slice table.
        # These are the reads and CPU time attributable to the slice.
        right_df = keys_df.merge(self.workload_df, how='right', **join_keys)
        matched = right_df[right_df[col_size].notna()]

        return (matched['inputDataSize'].sum(), len(cold_tables),
                cold_tables[col_size].sum(), matched['cputime'].sum())

    def print_table_info(self, df_slice, start_rank, end_rank):
        """Print workload stats for whale tables ranked [start_rank, end_rank), split by code."""
        subsets = [
            ('code=10', df_slice[df_slice['code'] == 10]),
            ('code!=10', df_slice[df_slice['code'] != 10]),
        ]
        for desc, curr_slice in subsets:
            read_size, cold_count, cold_size, total_cputime = self._workload_info(
                curr_slice, col_db='db_table', col_table='db_table', col_size='size_in_gbs')

            print(
                f"{start_rank}~{end_rank} tables with {desc}: Count: {len(curr_slice)} "
                f"| Size: {human_readable_size(curr_slice['size_in_gbs'].sum() * 1024 ** 3)} "
                f"| Read Size: {human_readable_size(read_size * 1024 ** 3)} "
                f"| Total CPU Time: {total_cputime} "
                f"| Cold Tables: {cold_count} | Cold Tables Size: {human_readable_size(cold_size * 1024 ** 3)}"
            )

    def print(self, output_csv='whale_tables.csv'):
        """Print the all / whale / other table comparison and export non-code-10 whales.

        Parameters:
        - output_csv: destination for the code!=10 whale tables (db_table,
          size_in_gbs), sorted by size descending. Defaults to 'whale_tables.csv'.
        """
        # 1. Count and total size of all tables.
        all_tables_count = len(self.all_table_df)
        all_tables_size = self.all_table_df['dir_size'].sum()
        _, all_tables_cold_count, all_tables_cold_size, _ = self._workload_info(
            self.all_table_df, col_db='hive_database_name', col_table='hive_table_name', col_size='dir_size')
        print(f"Total tables: {all_tables_count} | Total size: {human_readable_size(all_tables_size * 1024 ** 3)} | "
              f"Read Size: {human_readable_size(self.total_inputDataSize * 1024 ** 3)} | "
              f"Total CPU Time: {self.total_cputime} | "
              f"Cold Tables: {all_tables_cold_count} | Cold Tables Size: {human_readable_size(all_tables_cold_size * 1024 ** 3)}")

        # 2. Count and total size of the whale tables.
        top300_count = len(self.whale_df)
        top300_size = self.whale_df['size_in_gbs'].sum()
        top300_read_size, top300_cold_count, top300_cold_size, top300_cputime = self._workload_info(
            self.whale_df, col_db='db_table', col_table='db_table', col_size='size_in_gbs')

        print(f"Whale tables: {top300_count} | Total size: {human_readable_size(top300_size * 1024 ** 3)} | "
              f"Read Size: {human_readable_size(top300_read_size * 1024 ** 3)} | "
              f"Total CPU Time: {top300_cputime} | "
              f"Cold Tables: {top300_cold_count} | Cold Tables Size: {human_readable_size(top300_cold_size * 1024 ** 3)}")

        # 3. Everything that is not a whale table (all tables minus whales).
        diff_count = all_tables_count - top300_count
        diff_size = all_tables_size - top300_size
        diff_read_size = self.total_inputDataSize - top300_read_size
        diff_cold_count = all_tables_cold_count - top300_cold_count
        diff_cold_size = all_tables_cold_size - top300_cold_size
        diff_cpu_time = self.total_cputime - top300_cputime
        print(f"Other tables: {diff_count} | Total size: {human_readable_size(diff_size * 1024 ** 3)} | "
              f"Read Size: {human_readable_size(diff_read_size * 1024 ** 3)} | "
              f"Total CPU Time: {diff_cpu_time} | "
              f"Cold Tables: {diff_cold_count} | Cold Tables Size: {human_readable_size(diff_cold_size * 1024 ** 3)}")

        # 4. Whale tables by size rank, in buckets of 100.
        sorted_whale_df = self.whale_df.sort_values(by='size_in_gbs', ascending=False)
        for start, end in [(0, 100), (100, 200), (200, 300)]:
            self.print_table_info(sorted_whale_df.iloc[start:end], start, end)

        # Export the code!=10 whale tables (db_table, size_in_gbs).
        sorted_whale_df[sorted_whale_df['code'] != 10][['db_table', 'size_in_gbs']].to_csv(output_csv,
                                                                                           index=False)

# whale = Whale('whale_tables_source.csv',
#               'report-table-size.csv',
#               'report-abFP-volume-table-0908-0914.csv')
# whale.print()
