import os
import re

import pandas as pd

from utility import human_readable_size


def merge_similar_rows(df):
    """Fold rows placed like their database's ``group`` row into that row.

    ``df`` must provide the columns ``table`` (strings shaped like
    ``"<db>.<name>"``), ``z``, ``w`` and ``size``.  For every database that
    has exactly one ``"<db>.group"`` row whose ``(z, w)`` is not ``(0, 0)``,
    all other rows of that database with the same ``(z, w)`` are removed and
    the sum of their ``size`` is added to the group row.  Afterwards every
    row whose ``size`` is not strictly positive is dropped.

    The input frame is left untouched (the work happens on a copy).  The
    returned frame carries two extra helper columns, ``db_name`` and
    ``table_name``.  Progress is reported with ``print``.

    Returns:
        A ``(merged, removed)`` tuple: the reduced frame and a frame with
        every row that was dropped along the way (kept for debugging).

    Raises:
        AssertionError: if a database has more than one ``group`` row.
        IndexError: if a ``table`` value contains no ``.``.
    """
    df = df.copy()
    initial_length = len(df)

    # "<db>.<name>" -> db / name helper columns used for the grouping below.
    df['db_name'] = df['table'].apply(lambda x: x.split('.')[0])
    df['table_name'] = df['table'].apply(lambda x: x.split('.')[1])

    removed_rows = []  # for debug

    for db_name in df['db_name'].unique():
        group_label = f'{db_name}.group'
        group_row = df[df['table'] == group_label]

        if len(group_row) > 1:
            # Raised explicitly instead of via `assert` so the check still
            # fires under `python -O`; the exception type is unchanged.
            raise AssertionError("Multiple conflicting groups found")
        if group_row.empty:
            # print(f"[Merge Rows] {db_name}.group not found")
            continue  # Nothing to merge into for this database

        group_z, group_w = group_row['z'].values[0], group_row['w'].values[0]
        if group_z == 0 and group_w == 0:
            print(f"[Merge Rows] {db_name}.group is replicated (z=w=0), skipped")
            continue

        similar_rows = df[(df['db_name'] == db_name) &
                          (df['z'] == group_z) & (df['w'] == group_w) &
                          (df['table'] != group_label)]
        if similar_rows.empty:
            continue

        # Remember what is about to disappear, then move its size to the group.
        removed_rows.append(similar_rows.copy())
        df.loc[group_row.index, 'size'] += similar_rows['size'].sum()
        df = df.drop(similar_rows.index)

    print(f"Merging rows reduces rows from {initial_length} to {len(df)}.")

    second_length = len(df)

    # Keep only strictly positive sizes and record exactly the rows that this
    # filter drops (zero, negative or NaN sizes), so nothing vanishes silently.
    has_size = df['size'] > 0
    removed_rows.append(df[~has_size].copy())
    df = df[has_size]

    print(f"Removing rows with size 0 reduces rows from {second_length} to {len(df)}.")
    return df, pd.concat(removed_rows)


class Placement:
    """Compares consecutive dataset placements stored under one directory.

    Every entry of ``placement_list`` names a sub-directory of ``dir_name``
    holding a ``dataset_placement.csv`` file, and must contain
    ``local<digits>`` somewhere in its name.
    """

    def __init__(self, dir_name, placement_list, total_storage):
        self.dir_name = dir_name
        self.placement_list = placement_list
        # Stored under a "_gb" name; ingress_egress multiplies it by 1024 ** 3.
        self.total_storage_gb = total_storage

    @staticmethod
    def _local_fraction(run_name):
        """Return the integer following ``local`` in *run_name*, divided by 100.

        Raises:
            ValueError: if *run_name* contains no ``local<digits>`` part.
        """
        found = re.search(r"local(\d+)", run_name)
        if found is None:
            raise ValueError("Pattern not found")
        return int(found.group(1)) / 100

    @staticmethod
    def _align_before_to_after(df_bf, df_af):
        """Rework the "before" frame so its ``table`` column matches "after".

        Both frames are expected to be sorted by ``table`` with a fresh
        0..n-1 index.  Returns the reworked "before" frame.
        """
        group_re = re.compile(r"(.+)\.group")
        table_re = re.compile(r"(.+)\.(.+)")

        # Rows that only exist in the "after" placement.
        only_in_af = df_af[~df_af['table'].isin(df_bf['table'])]

        # Group rows take over the size reported by the "after" placement.
        for pos, bf_row in df_bf.iterrows():
            if not group_re.match(bf_row['table']):
                continue
            same_table = df_af[df_af['table'] == bf_row['table']]
            if not same_table.empty:
                df_bf.at[pos, 'size'] = same_table['size'].values[0]

        # Every "after"-only row is added to "before"; when its database has
        # a group row in "before", the row inherits that group's z and w.
        additions = []
        for _, af_row in only_in_af.iterrows():
            parsed = table_re.match(af_row['table'])
            if parsed:
                owner = df_bf[df_bf['table'] == f"{parsed.group(1)}.group"]
                if not owner.empty:
                    af_row['z'] = owner['z'].values[0]
                    af_row['w'] = owner['w'].values[0]
            additions.append(af_row)

        df_bf = pd.concat([df_bf, pd.DataFrame(additions)], ignore_index=True)
        df_bf = df_bf.sort_values('table').reset_index(drop=True)

        # debug
        if not df_bf['table'].equals(df_af['table']):
            print("[Debug] Tables in bf but not in af:")
            stale = df_bf[~df_bf['table'].isin(df_af['table'])]
            print(stale)
            df_bf = df_bf.drop(stale.index)
            df_bf = df_bf.sort_values('table').reset_index(drop=True)
            assert df_bf['table'].equals(df_af['table']), f"Tables do not match {len(df_bf)} {len(df_af)}"

        return df_bf

    def ingress_egress(self, debug=False):
        """Report size-weighted ingress/egress between neighbouring placements.

        For each adjacent pair in ``placement_list`` the two
        ``dataset_placement.csv`` files (header: table, z, w, size) are
        loaded, brought to the same set of tables and compared row by row.
        Progress and results are printed.  ``debug`` is currently unused.

        Returns:
            A list with one dict per pair, keyed ``'from'``, ``'to'``,
            ``'ingress'`` and ``'egress'``; the last two hold whatever
            ``human_readable_size`` returns.

        Raises:
            ValueError: if a placement name contains no ``local<digits>``.
        """
        summary = []
        runs = self.placement_list

        for src, dst in zip(runs, runs[1:]):
            print(src, "->", dst)

            # header: table, z, w, size
            csv_before = os.path.join(self.dir_name, src, 'dataset_placement.csv')
            # the number after "local" in each name: xxx_local{int}xxxx
            before = self._local_fraction(src)
            after = self._local_fraction(dst)

            print("before", before, "vs after", after)

            csv_after = os.path.join(self.dir_name, dst, 'dataset_placement.csv')

            df_bf = pd.read_csv(csv_before)
            print("# of rows in bf before merge:", len(df_bf))
            df_bf, _ = merge_similar_rows(df_bf)
            print("# of rows in bf after merge:", len(df_bf))

            # The "after" frame is only filtered for positive sizes; it is
            # not passed through merge_similar_rows.
            df_af = pd.read_csv(csv_after)
            print("# of rows in af before merge:", len(df_af))
            df_af = df_af[df_af['size'] > 0]
            print("# of rows in af after merge:", len(df_af))

            # Same ordering and a fresh index on both sides, so the
            # comparisons below line up row by row.
            df_bf = df_bf.sort_values('table').reset_index(drop=True)
            df_af = df_af.sort_values('table').reset_index(drop=True)

            if not df_bf['table'].equals(df_af['table']):
                df_bf = self._align_before_to_after(df_bf, df_af)

            # Calculate ingress and egress
            moved_in = (df_bf['z'] == 0) & (df_bf['w'] == 1) & (df_af['w'] == 0)
            moved_out = (df_bf['z'] == 1) & (df_bf['w'] == 0) & (df_af['z'] == 0)

            # Weight each flagged row by its "before" size.
            ingress_count = (moved_in * df_bf['size']).sum()
            egress_count = (moved_out * df_bf['size']).sum()

            # Save the results for this pair
            summary.append({
                'from': src,
                'to': dst,
                'ingress': human_readable_size(ingress_count * 1024 ** 3),
                'egress': human_readable_size(egress_count * 1024 ** 3),
            })

            print(f"From {src} to {dst}:"
                  f" ingress: {human_readable_size(ingress_count * 1024 ** 3)},"
                  f" egress: {human_readable_size(egress_count * 1024 ** 3)},"
                  f" remove {human_readable_size((self.total_storage_gb * 1024 ** 3 * (before - after)))} from onprem")

        print("Results:", summary)

        return summary


# Driver: compare each run directory under 'long_term' with the next one in
# the list.  ingress_egress() reads the number after "local" in every name
# and divides it by 100, so the list below goes from 0.9 down to 0.1.
# NOTE(review): the value is stored as `total_storage_gb` and multiplied by
# 1024 ** 3 when printed -- confirm the intended unit of 299.12 * 1024 ** 2.
p = Placement(
    dir_name='long_term',
    placement_list=[
        'test_run_c10_bw0.02_local90',
        'test_run_c20_bw0.02_local80',
        'test_run_c30_bw0.02_local70',
        'test_run_c40_bw0.02_local60',
        'test_run_c50_bw0.02_local50',
        'test_run_c60_bw0.02_local40',
        'test_run_c70_bw0.02_local30',
        'test_run_c80_bw0.02_local20',
        'test_run_c90_bw0.02_local10',
    ],
    total_storage=299.12 * 1024 ** 2,
)
p.ingress_egress()