import pandas as pd
import numpy as np
from gurobipy import Model, GRB, LinExpr
import os
import time

from utility import human_readable_size, to_seconds
import sys

from Yugong.Ownership import Ownership

from collections import defaultdict

def print_time(start_time, end_time, msg, file=sys.stdout):
    print(msg, f"{end_time - start_time:.2f}", "seconds", file=file, flush=True)

def merge_similar_rows(df, yugong: bool):
    df = df.copy()
    initial_length = len(df)

    df['db_name'] = df['table'].apply(lambda x: x.split('.')[0])
    df['table_name'] = df['table'].apply(lambda x: x.split('.')[1])

    removed_rows = [] # for debug

    if not yugong:
        for db_name in df['db_name'].unique():
            group_row = df[df['table'] == f'{db_name}.group']

            assert len(group_row) <= 1, f"Multiple conflicting group rows found for {db_name}"
            if group_row.empty:
                print(f"[Merge Rows] {db_name}.group not found")
                continue  # skip databases without a group row

            group_z, group_w = group_row['z'].values[0], group_row['w'].values[0]
            if group_z == 0 and group_w == 0:
                print(f"[Merge Rows] {db_name}.group is replicated (z=w=0), skipped")
                continue

            similar_rows = df[(df['db_name'] == db_name) &
                                (df['z'] == group_z) & (df['w'] == group_w) &
                                (df['table'] != f'{db_name}.group')]
            if not similar_rows.empty:
                # Store removed rows before dropping
                removed_rows.append(similar_rows.copy())

                # fold the merged rows' sizes into the group row
                group_size = similar_rows['size'].sum()
                df.loc[group_row.index, 'size'] += group_size
                df = df.drop(similar_rows.index)

        print(f"Merging rows reduces rows from {initial_length} to {len(df)}.")

    second_length = len(df)

    removed_rows.append(df[df['size'] == 0].copy())

    # Drop rows where size is 0
    df = df[df['size'] > 0]

    print(f"Removing rows with size 0 reduces rows from {second_length} to {len(df)}.")
    return df, pd.concat(removed_rows)
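
# A worked example (hedged; illustrative values) of merge_similar_rows with
# yugong=False. Given a placement DataFrame:
#
#   table      z  w  size
#   db1.group  1  0  10
#   db1.t1     1  0   5    # same (z, w) as db1.group -> folded into it
#   db1.t2     0  1   3
#   db1.t3     0  1   0    # size 0 -> dropped
#
# the result keeps db1.group (size 15, absorbing db1.t1) and db1.t2; the
# second return value collects all removed rows for debugging.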


class Query_on_DB_Table:
    def __init__(self, job_data_access_df,
                 workload_print_info,
                 db_table_size_file_name,
                 rep_threshold=None,
                 rep_strategy="",
                 yugong=False,
                 ownership=None,
                 rep_list=None,
                 k=1.0,
                 log_dir='.',
                 ):
        self.df = None
        self.previous_placement_path = None
        self.prev_z = None
        self.prev_w = None

        self.db_table_num = None
        self.abFP_num = None
        self.total_storage_gb = None
        self.dataset_num = None

        self.job_data_access_df = job_data_access_df
        self.workload_print_info = workload_print_info
        print("workload info", workload_print_info)
        self.db_table_size_path = db_table_size_file_name
        print("db_table_size path", db_table_size_file_name)

        self.rep_threshold = rep_threshold
        self.rep_strategy = rep_strategy
        self.rep_constr = []
        self.rep_list = []

        self.yugong = yugong
        self.ownership = None
        if yugong:
            print("** Yugong mode **")
            assert ownership is not None, "Ownership must be provided in Yugong mode"
            assert rep_list is not None, "Replication list must be provided in Yugong mode (to align with Moirai)"
            self.ownership = ownership
            self.rep_list = rep_list

        self.s = None
        self.adj_list_input = defaultdict(dict)  # Maps table id -> {job id: input_size}
        self.adj_list_output = defaultdict(dict)  # Maps table id -> {job id: output_size}
        self.c = None

        self.k = k
        self.X_scale = 0
        self.load_workload()
        assert self.X_scale > 0

        self.df_table_size = None
        self.prepare_db_table_size(db_table_size_file_name)
        assert self.df_table_size is not None

        self.unique_abFP = {}
        self.unique_db_tables = self.prepare_replication()
        self.prepare_workload()

        self.model = None  # gurobi model
        self.prepare_basic_model()
        self.y, self.z, self.w, self.u, self.v = (None, None, None, None, None)

        self.log_dir = log_dir
        self.add_y_z_w_u_v(self.abFP_num, self.dataset_num, True, log_dir=log_dir)
        self.workload_constrs = self.add_workload_constr()


    def __del__(self):
        if hasattr(self, 'model') and self.model is not None:
            self.model.dispose()

    def restore_unique_db_tables(self, file_path, log_dir=None):
        prev_placement_df = pd.read_csv(file_path)
        prev_placement_df['table'] = prev_placement_df['table'].astype(str)
        prev_placement_df, removed_df = merge_similar_rows(prev_placement_df, self.yugong)
        if log_dir is not None:
            removed_df.to_csv(os.path.join(log_dir, 'removed_rows.csv'), index=False)
        print("clean up unique_db_tables", len(self.unique_db_tables), "to", len(prev_placement_df))
        self.unique_db_tables = {}
        count = 0
        for _, row in prev_placement_df.iterrows():
            db_table = row['table']
            self.unique_db_tables[db_table] = count
            count += 1

    def update_workload(self, job_data_access_df, workload_print_info, log_dir=None):
        # TODO: db_table_size_file_name is not updated
        print("workload path from", self.workload_print_info, "to", workload_print_info)
        self.job_data_access_df = job_data_access_df
        self.workload_print_info = workload_print_info

        # TODO: no decaying, just completely replace

        # Updates self.df and self.c; self.abFP_num and self.db_table_num are
        # recomputed but stale (workload-local) until prepare_workload() below
        # re-derives them from the cumulative id maps
        self.load_workload()

        self.unique_abFP = {}  # reset abFP
        self.prepare_workload()
        # self.visualize_a()  # debug

        # TODO: omit update replication for now (old constraints still there)

        # update gurobi variables
        self.model.remove(self.workload_constrs)
        self.update_y_z_w_u_v(self.abFP_num, self.dataset_num, binary=True, log_dir=log_dir)
        self.workload_constrs = self.add_workload_constr()
        self.model.update()

    def load_workload(self):
        # format: abstractFingerPrint,db_name,table_name,
        # inputDataSize, outputDataSize (in bytes),
        # cputime (in seconds)
        self.df = self.job_data_access_df
        self.df['totalDataSize'] = self.df['inputDataSize'] + self.df['outputDataSize']

        k = self.k
        assert 0 < k <= 1, f"k={k} does not satisfy 0 < k <= 1"
        if k < 1:
            # Calculate total inputDataSize for each abstractFingerPrint
            abFP_sizes = self.df.groupby('abstractFingerPrint')['totalDataSize'].sum()

            # Sort abFPs by inputDataSize in descending order
            abFP_sizes = abFP_sizes.sort_values(ascending=False)

            # Determine the top k% of abFPs to keep
            top_k_count = int(len(abFP_sizes) * k)
            top_k_abFPs = abFP_sizes.head(top_k_count).index

            # Filter the dataframe to include only the selected top k abFPs
            self.df = self.df[self.df['abstractFingerPrint'].isin(top_k_abFPs)]

            # Calculate the percentage of accesses retained
            total_access_size = abFP_sizes.sum()
            top_k_access_size = abFP_sizes.loc[top_k_abFPs].sum()
            percent_access_size = (top_k_access_size / total_access_size) * 100

            print(f"Top {k * 100:.2f}% of abFPs (# {top_k_count}) "
                  f"contribute {percent_access_size:.2f}% of total read/write accesses")
            self.X_scale = percent_access_size / 100
        else:
            self.X_scale = 1

        assert self.df is not None
        row_num = len(self.df)
        self.abFP_num = self.df['abstractFingerPrint'].nunique()
        self.db_table_num = self.df.groupby(['db_name', 'table_name']).ngroups
        print("# row", row_num)
        print("# db_table in workload", self.db_table_num)
        print("# abFP", self.abFP_num)

        self.c = self.df.groupby('abstractFingerPrint')['cputime'].sum()
        print("c created")

    def prepare_db_table_size(self, file_name):
        # relevant columns: hive_database_name, hive_table_name, dir_size
        self.df_table_size = pd.read_csv(file_name)
        self.df_table_size['hive_database_name'] = self.df_table_size['hive_database_name'].astype(str)
        self.df_table_size['hive_table_name'] = self.df_table_size['hive_table_name'].astype(str)

        # how many lines have dir_size != 0
        print("# of lines with dir_size != 0", self.df_table_size[self.df_table_size['dir_size'] != 0].shape[0])
        # keep only rows with dir_size > 0
        self.df_table_size = self.df_table_size[self.df_table_size['dir_size'] > 0]

        # how many workload rows have inputDataSize == 0
        print("# of lines with inputDataSize == 0", self.df[self.df['inputDataSize'] == 0].shape[0])
        # convert dir_size from bytes to GB
        self.df_table_size['dir_size'] = self.df_table_size['dir_size'] / 1024 ** 3
        # the total size of the tables
        self.total_storage_gb = self.df_table_size['dir_size'].sum()
        print("total size of the tables", human_readable_size(self.total_storage_gb * 1024 ** 3))

    def prepare_replication(self):
        def _compute_edges():
            edge_per_table = defaultdict(int)
            for _, row in self.df.iterrows():
                idx = f"{row['db_name']}.{row['table_name']}"
                edge_per_table[idx] += 1
            return edge_per_table
        def _compute_access_size():
            access_per_table = defaultdict(int)
            for _, row in self.df.iterrows():
                idx = f"{row['db_name']}.{row['table_name']}"
                access_per_table[idx] += row['inputDataSize'] + row['outputDataSize']
            return access_per_table
        def _load_table_size():
            size_per_table = {}
            for _, row in self.df_table_size.iterrows():
                idx = f"{row['hive_database_name']}.{row['hive_table_name']}"
                size_per_table[idx] = row['dir_size']
            return size_per_table

        assert self.df is not None
        print("Replication strategy:", self.rep_strategy)
        print("No matter read or write or both, we count as 1 edge")
        # print("Default: rank tables by # edges normalized table size (JAD)")
        if self.rep_strategy == "job_access_density":
            edges = _compute_edges()
            size_per_table = _load_table_size()
            metric_per_table = {k: edges[k] / size_per_table[k]
                                if k in size_per_table and size_per_table[k] > 0 else 0
                                for k in edges}
        elif self.rep_strategy == "read_traffic_volume":
            metric_per_table = _compute_access_size()
        elif self.rep_strategy == "inverse_dataset_size":
            edges = _compute_edges() # used only for the keys
            size_per_table = _load_table_size()
            metric_per_table = {k: 1 / size_per_table[k]
                                if k in size_per_table and size_per_table[k] > 0 else 0
                                for k in edges}
            #metric_per_table = {k: 1 / v if v > 0 else 0 for k, v in size_per_table.items()}
        elif self.rep_strategy == "job_access_frequency":
            metric_per_table = _compute_edges()
        elif self.rep_strategy == "read_traffic_density":
            access_size = _compute_access_size()
            size_per_table = _load_table_size()
            metric_per_table = {k: access_size[k] / size_per_table[k]
                                if k in size_per_table and size_per_table[k] > 0 else 0
                                for k in access_size}
        else:
            raise ValueError(f"Unknown replication strategy: {self.rep_strategy}")

        sorted_table = {k: v for k, v in sorted(metric_per_table.items(), key=lambda item: -item[1])}
        unique_db_tables = {idx: rank for rank, idx in enumerate(sorted_table)}

        return unique_db_tables
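
    # Example (hedged, illustrative numbers): with edges = {"db.t": 4} and
    # size_per_table = {"db.t": 2.0} (GB), "job_access_density" scores db.t
    # as 4 / 2.0 = 2.0; tables are then ranked by descending score and the
    # rank becomes the table id used throughout the model.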

    def prepare_workload(self):
        """ clean_df() deprecated
        def clean_df():
            # remove lines in self.df where db.table has no size

            print("\ncleaning df start...")
            # Create a set of valid db.table combinations from df_table_size
            valid_db_tables = set(
                self.df_table_size.apply(lambda row: f"{row['hive_database_name']}.{row['hive_table_name']}", axis=1))

            initial_num_rows = len(self.df)
            initial_total_input_size = self.df['inputDataSize'].sum()
            initial_total_cputime = self.df['cputime'].sum()

            removed_rows = []

            for index, row in self.df.iterrows():
                db_table = f"{row['db_name']}.{row['table_name']}"
                if db_table not in valid_db_tables:
                    removed_rows.append(row)

            # Remove the rows from the dataframe
            self.df = self.df.drop([row.name for row in removed_rows])

            # Calculate remaining
            remaining_total_input_size = self.df['inputDataSize'].sum()
            remaining_total_cputime = self.df['cputime'].sum()

            print(f"# of lines removed: {len(removed_rows)} vs {initial_num_rows}")
            # print(
            #     f"# of unique tables involved in removal: {len(set(f'{row['db_name']}.{row['table_name']}' for row in removed_rows))}")
            print(f"Total CPU time influenced: {initial_total_cputime - remaining_total_cputime:.0f}"
                  f" ({(initial_total_cputime - remaining_total_cputime)/initial_total_cputime*100:.1f}%)")
            print(f"Total inputDataSize influenced: {human_readable_size(initial_total_input_size-remaining_total_input_size)}"
                  f" vs {human_readable_size(initial_total_input_size)}")

            self.abFP_num = self.df['abstractFingerPrint'].nunique()
            self.db_table_num = self.df.groupby(['db_name', 'table_name']).ngroups
            print("[After clean] # abFP", self.abFP_num)
            print("[After clean] # db_table in workload", self.db_table_num)

            self.c = self.df.groupby('abstractFingerPrint')['cputime'].sum()
            print("[After clean] compute c", self.c)

            print("cleaning df end...\n", flush=True) """

        def prepare_unique_abFP():
            counter_a = len(self.unique_abFP)  # old
            counter_t = len(self.unique_db_tables)  # old
            print(f"from {counter_a} x {counter_t}", end="")
            for index, row in self.df.iterrows():
                i_string = row['abstractFingerPrint']
                j_string = f"{row['db_name']}.{row['table_name']}"
                if i_string not in self.unique_abFP:
                    self.unique_abFP[i_string] = counter_a
                    counter_a += 1
                if j_string not in self.unique_db_tables:
                    self.unique_db_tables[j_string] = counter_t
                    counter_t += 1
            assert counter_a == len(self.unique_abFP) and counter_t == len(self.unique_db_tables)
            print(f" to {counter_a} x {counter_t}")

            self.abFP_num = len(self.unique_abFP)
            self.db_table_num = len(self.unique_db_tables)

            self.adj_list_input = defaultdict(dict)
            self.adj_list_output = defaultdict(dict)

            for index, row in self.df.iterrows():
                i_string = row['abstractFingerPrint']
                j_string = f"{row['db_name']}.{row['table_name']}"  # This needs proper mapping
                i = self.unique_abFP[i_string]
                j = self.unique_db_tables[j_string]
                self.adj_list_input[j][i] = row['inputDataSize'] / 1024 ** 3  # convert to GB
                self.adj_list_output[j][i] = row['outputDataSize'] / 1024 ** 3

        # if clean:
        #     clean_df()

        prepare_unique_abFP()

        counter_t = len(self.unique_db_tables)
        pair_num = len(self.df)
        print("[sanity check] # of non-zero edges [i,j]", pair_num, flush=True)
        # print("should no more than # of rows in self.df", len(self.df), flush=True)

        # expand s variable to all db_table
        assert self.df_table_size is not None
        # Identify the db_tables present in self.df_table_size
        db_tables_in_df = set(
            self.df_table_size.apply(lambda row: f"{row['hive_database_name']}.{row['hive_table_name']}", axis=1))

        # Find the db_tables that are in self.df_table_size but not in self.unique_db_tables
        extra_db_tables = db_tables_in_df - set(self.unique_db_tables.keys())

        if self.yugong:
            self.df_table_size['project'] = self.df_table_size.apply(
                lambda row: self.ownership.get_table_ownership(f"{row['hive_database_name']}.{row['hive_table_name']}"), axis=1)
            missing_sizes = self.df_table_size[self.df_table_size.apply(
                lambda row: f"{row['hive_database_name']}.{row['hive_table_name']}" in extra_db_tables, axis=1)]
            grouped_sizes = missing_sizes.groupby('project')['dir_size'].sum()
            group_num = len(grouped_sizes)
            print("# of grouped untouched projects this time period", group_num, flush=True, end=' ')
            for project in grouped_sizes.index:
                j_string = f"{project}.group"
                self.ownership.add_table_ownership(j_string, project)
                assert self.ownership.get_table_ownership(j_string) == project, f"Ownership not set for {j_string}"
        else:
            # Filter self.df_table_size to keep only the extra db_tables, then group by hive_database_name and sum the sizes
            missing_sizes = self.df_table_size[self.df_table_size.apply(
                lambda row: f"{row['hive_database_name']}.{row['hive_table_name']}" in extra_db_tables, axis=1)]
            grouped_sizes = missing_sizes.groupby('hive_database_name')['dir_size'].sum()
            group_num = len(grouped_sizes)
            print("# of grouped untouched dbs this time period", group_num, flush=True, end=' ')

        count_bf = len(self.unique_db_tables)
        # allocate id for untouched dbs
        for db_name in grouped_sizes.index:
            j_string = f"{db_name}.group"
            if j_string not in self.unique_db_tables:
                self.unique_db_tables[j_string] = counter_t
                counter_t += 1
        print("(only", counter_t - count_bf, "newly appeared dbs)", flush=True)

        # self.dataset_num = self.db_table_num + group_num
        assert counter_t == len(self.unique_db_tables), f"{counter_t} != {len(self.unique_db_tables)}"
        self.dataset_num = counter_t
        for table_id in range(self.db_table_num, self.dataset_num):
            self.adj_list_input[table_id] = {}
            self.adj_list_output[table_id] = {}
        print(f"adjacency lists have tables {len(self.adj_list_input)}", flush=True)

        self.s = np.zeros(self.dataset_num)
        start = time.time()

        size_lookup = {
            (row.hive_database_name, row.hive_table_name): row.dir_size
            for row in self.df_table_size.itertuples(index=False)
        }

        sum_hot_gb = 0
        # set_size = self.df_table_size['dir_size'].min()
        set_size = 0
        for db_table in self.unique_db_tables:
            parts = db_table.split('.')
            if len(parts) != 2:
                raise ValueError(f"Invalid db_table format: {db_table} into {parts}")
            db_name, table_name = parts
            size = size_lookup.get((db_name, table_name), None)
            if size is None or size == 0:
                if db_name in grouped_sizes.index:
                    continue  # handled via the <db>.group entry in the next block
                # print("Warning: db_table not found in table size file", db_table, "set to", set_size)
                self.s[self.unique_db_tables[db_table]] = set_size
                sum_hot_gb += set_size
            else:
                self.s[self.unique_db_tables[db_table]] = size
                sum_hot_gb += size
        print(f"Touched data size: {human_readable_size(sum_hot_gb * 1024 ** 3)}")

        # cold datasets that are not in the workload
        sum_gb = 0
        for db_name in map(str, grouped_sizes.index):
            j_string = f"{db_name}.group"
            self.s[self.unique_db_tables[j_string]] = grouped_sizes[db_name]
            sum_gb += grouped_sizes[db_name]
        print(f"Non-touched data size: {human_readable_size(sum_gb * 1024 ** 3)}")
        print_time(start, time.time(), "s created")

    def prepare_basic_model(self):
        if self.model is None:
            self.model = Model("query_on_db_table")
            if self.model is None:
                print('Could not create model')
                return None

        # Cap the number of solver threads (12 here; tune to the machine)
        self.model.setParam('Threads', 12)

        self.model.setParam("NodefileStart", 0.1)  # Start writing to disk at 10% of RAM usage
        self.model.setParam("NodefileDir", ".")
        self.model.setParam("MemLimit", 64 * 1024)  # 64 GB

    def add_y_z_w_u_v(self, N, M, binary=True, log_dir='.', file=sys.stdout):
        u = {}
        v = {}

        pair_num = len(self.df)
        temp_counter = 0
        step = pair_num // 10
        start = time.time()
        print("init y,z,w", file=file, flush=True)

        assert binary is True
        # Decision variables
        # y[i]: If workload i is executed in D1, y[i] = 0; if in D2, y[i] = 1.
        y = self.model.addVars(N, vtype=GRB.BINARY, name='y')

        # z[j]: If dataset j exists in D1, z[j] = 0; else, z[j] = 1.
        z = self.model.addVars(M, vtype=GRB.BINARY, name='z')

        # w[j]: If dataset j exists in D2, w[j] = 0; else, w[j] = 1.
        w = self.model.addVars(M, vtype=GRB.BINARY, name='w')

        print_time(start, time.time(), "y,z,w created", file=file)

        if self.yugong:
            print("Enforce replication on the same set of tables with Moirai")
            print("rep_threshold", self.rep_threshold)
            # print(f"rep_threshold={self.rep_threshold} < 0, z+w=1, no replication")

            replicated_indices = set()
            for table in self.rep_list:
                j = self.unique_db_tables.get(table, None)
                if j is None:
                    print(f"[Warning] Table {table} not found in unique_db_tables, skip for replication")
                    continue

                self.rep_constr.append(self.model.addConstr(z[j] == 0, name=f'z_{j}_0'))
                self.rep_constr.append(self.model.addConstr(w[j] == 0, name=f'w_{j}_0'))
                replicated_indices.add(j)

            # # First, replicate 0.004 of data
            # threshold_gb = self.total_storage_gb * 0.004
            # print(f"Replicate 0.4% data = {human_readable_size(threshold_gb * 1024 ** 3)}")
            # total_size_gb, rep_count = 0, 0
            # replicated_indices = set()
            # for j in range(M):
            #     if self.s[j] > 0:
            #         if total_size_gb + self.s[j] >= threshold_gb + 1024:  # 1TB buffer
            #             continue
            #         total_size_gb += self.s[j]
            #         self.rep_constr.append(self.model.addConstr(z[j] == 0, name=f'z_{j}_0'))
            #         self.rep_constr.append(self.model.addConstr(w[j] == 0, name=f'w_{j}_0'))
            #         replicated_indices.add(j)
            #         rep_count += 1
            #         for key in self.unique_db_tables:
            #             if j == self.unique_db_tables[key]:
            #                 self.rep_list.append(key)
            #     if total_size_gb >= threshold_gb:
            #         print(
            #             f"Replicated Total {rep_count} Tables of {human_readable_size(total_size_gb * 1024 ** 3)} till pos {j}")
            #         print(f"# of replicated tables logged down:", len(self.rep_list))
            #         break

            # Then enforce existence of each table (every dataset must live in at least one site)
            for j in range(M):
                if j not in replicated_indices:
                    self.model.addConstr(z[j] + w[j] <= 1, name=f'zw_{j}')
        else:
            if self.rep_threshold is not None and self.rep_threshold < 0:
                print(f"rep_threshold={self.rep_threshold} < 0, z+w=1, no replication")
                for j in range(M):
                    self.model.addConstr(z[j] + w[j] == 1, name=f'zw_{j}')
            else:
                print(f"rep_threshold={self.rep_threshold}, z+w<=1, replication allowed")
                for j in range(M):
                    self.model.addConstr(z[j] + w[j] <= 1, name=f'zw_{j}')

            if self.rep_threshold is not None and self.rep_threshold > 0:
                assert len(self.rep_list) == 0
                threshold_gb = self.total_storage_gb * self.rep_threshold
                print(
                    f"replicate top {self.rep_threshold * 100:.2f}% data = {human_readable_size(threshold_gb * 1024 ** 3)}")
                total_size_gb, rep_count = 0, 0  # running size and count of replicated tables
                for j in range(M):
                    if self.s[j] > 0:
                        if total_size_gb + self.s[j] >= threshold_gb + 1024:  # 1TB
                            continue
                        total_size_gb += self.s[j]
                        self.rep_constr.append(self.model.addConstr(z[j] == 0, name=f'z_{j}_0'))
                        self.rep_constr.append(self.model.addConstr(w[j] == 0, name=f'w_{j}_0'))
                        rep_count += 1
                        for key in self.unique_db_tables:
                            if j == self.unique_db_tables[key]:
                                self.rep_list.append(key)
                                break
                    if total_size_gb >= threshold_gb:
                        print(f"Replicate Total {rep_count} Tables of {human_readable_size(total_size_gb * 1024 ** 3)} till pos {j}")
                        print(f"# of replicated tables logged down:", len(self.rep_list))
                        with open(os.path.join(log_dir, f"replicated_tables_{self.rep_threshold}_{self.rep_strategy}.csv"), 'w') as f:
                            f.write("replicated_tables\n")
                            for key in self.rep_list:
                                f.write(f"{key}\n")
                        break

        start = time.time()
        print("If sample rate < 1, progress is usually under-estimated, nothing wrong happened")
        print("init u, v", file=file, flush=True)
        replicated_indices = [self.unique_db_tables[key] for key in self.rep_list]

        for j in range(M):
            if j in replicated_indices:
                continue
            job_ids = self.adj_list_input[j].keys()
            for i in job_ids:
                u[(i, j)] = self.model.addVar(vtype=GRB.BINARY, name=f'u_{i}_{j}')
                v[(i, j)] = self.model.addVar(vtype=GRB.BINARY, name=f'v_{i}_{j}')
                temp_counter += 1
                if step != 0 and temp_counter % step == 0:
                    print(f"== progress:{temp_counter / pair_num * 100:.0f}%", file=file, flush=True)

        print_time(start, time.time(), "u, v created", file=file)

        if self.yugong:
            project_list = []
            # print("unique_db_tables", self.unique_db_tables, file=file, flush=True) # debug
            for key in self.unique_db_tables:
                table_ownership = self.ownership.get_table_ownership(key)
                # assert table_ownership is not None, f"Table ownership not found for {key}"
                if table_ownership not in project_list:
                    project_list.append(table_ownership)
            for key in self.unique_abFP:
                query_ownership = key
                if query_ownership not in project_list:
                    project_list.append(query_ownership)
            print("# of projects", len(project_list), file=file, flush=True)
            print("project_list", project_list, file=file, flush=True)  # debug

            # init a dict for each project to store the jobs in the project
            project_jobs = {project: [] for project in project_list}
            for key in self.unique_abFP:
                i = self.unique_abFP[key]
                project = key
                project_jobs[project].append(i)
            # Add constraints for each project to ensure all jobs are either on-premises or in the cloud
            project_vars = {project: self.model.addVar(vtype=GRB.BINARY, name=f'y_project_{project}') for project in project_list}
            for project in project_jobs:
                project_var = project_vars[project]
                for i in project_jobs[project]:
                    self.model.addConstr(y[i] == project_var, name=f'y_project_{project}_{i}')

            project_tables = {project: [] for project in project_list}
            for key in self.unique_db_tables:
                j = self.unique_db_tables[key]
                project = self.ownership.get_table_ownership(key)
                project_tables[project].append(j)
            # Add constraints for each project to ensure all tables are either on-premises or in the cloud
            for project in project_tables:
                project_var = project_vars[project]
                for j in project_tables[project]:
                    # Replicated tables need no special handling here: a
                    # replicated table has z[j] = w[j] = 0, so it is available
                    # wherever the project is placed and satisfies both
                    # indicator constraints trivially.
                    self.model.addGenConstrIndicator(project_var, 0, z[j] == 0, name=f'z_project_{project}_{j}')
                    self.model.addGenConstrIndicator(project_var, 1, w[j] == 0, name=f'w_project_{project}_{j}')
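                    # addGenConstrIndicator(b, val, constr) requires constr to
                    # hold whenever binary b equals val: a project placed
                    # on-prem (project_var == 0) forces its tables on-prem
                    # (z[j] == 0), and a project in the cloud forces w[j] == 0.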

        self.y, self.z, self.w, self.u, self.v = (y, z, w, u, v)

    def update_y_z_w_u_v(self, N, M, binary=True, log_dir=None):
        os.makedirs(log_dir, exist_ok=True)
        file = open(os.path.join(log_dir, "log.txt"), 'a')

        old_N = len(self.y) if hasattr(self, 'y') else 0
        old_M = len(self.z) if hasattr(self, 'z') else 0
        print("old N, M", old_N, old_M)

        start = time.time()
        print("init y, z, w", file=file, flush=True)

        self.model.remove(self.y)
        self.y = self.model.addVars(N, vtype=GRB.BINARY, name='y')

        # remove replication constr because db_table can change
        self.model.remove(self.rep_constr)
        self.rep_constr = []

        self.model.remove(self.z)
        self.model.remove(self.w)
        self.z = self.model.addVars(M, vtype=GRB.BINARY, name='z')
        self.w = self.model.addVars(M, vtype=GRB.BINARY, name='w')
        for j in range(M):
            self.model.addConstr(self.z[j] + self.w[j] <= 1, name=f'zw_{j}')

        if self.rep_threshold is not None:
            for key in self.rep_list:
                if key in self.unique_db_tables:
                    j = self.unique_db_tables[key]
                    self.rep_constr.append(self.model.addConstr(self.z[j] == 0, name=f'z_{j}_0'))
                    self.rep_constr.append(self.model.addConstr(self.w[j] == 0, name=f'w_{j}_0'))
                else:
                    print("Warning: replicated table not found in unique_db_tables", key)

        print_time(start, time.time(), "y, z, w created", file=file)

        start = time.time()
        print("If sample rate < 1, progress is usually under-estimated, nothing wrong happened")
        print("init u, v", file=file, flush=True)
        # Rebuild u, v from scratch
        self.model.remove(self.u)
        self.model.remove(self.v)
        self.u = {}
        self.v = {}

        replicated_indices = [self.unique_db_tables[key] for key in self.rep_list if key in self.unique_db_tables]
        for j in range(M):
            if j in replicated_indices:
                continue
            job_ids = self.adj_list_input[j].keys()
            for i in job_ids:
                self.u[(i, j)] = self.model.addVar(vtype=GRB.BINARY, name=f'u_{i}_{j}')
                self.v[(i, j)] = self.model.addVar(vtype=GRB.BINARY, name=f'v_{i}_{j}')

        print_time(start, time.time(), "u, v created", file=file)
        file.close()
        # Update model to integrate new variables
        self.model.update()
        print("updated to new N, M", len(self.y), len(self.z), "=", len(self.w))
        print("updated to new U, V", len(self.u), len(self.v))

    def add_workload_constr(self, f_print=sys.stdout):

        N = self.abFP_num
        M = self.dataset_num
        y, z, w, u, v = self.get_y_z_w_u_v()

        constrs = []

        """ zw[j] = z[j] * w[j] always = 0 """

        """ u[i, j] = (1-y[i])*(z[j]-z[j]*w[j)) = (1-y[i])*z[j]
        u[i, j] = 1 only when y[i] == 0 (job on-prem) and z[j] == 1 (z represents the state of data on-prem, 1 means not in on-prem)
        This should be egress for input data (or say read), ingress for output data (or say write)
        """
        pair_num = len(self.df)
        temp_counter = 0
        step = pair_num // 10
        start = time.time()
        print("u constr start", file=f_print, flush=True)

        replicated_indices = [self.unique_db_tables[key] for key in self.rep_list if key in self.unique_db_tables]

        for j in range(M):
            if j in replicated_indices:
                continue
            job_ids = self.adj_list_input[j].keys()
            for i in job_ids:
                temp_counter += 1
                constrs.append(self.model.addConstr(u[(i, j)] + y[i] <= 1, name=f'u1_{i}_{j}'))
                constrs.append(self.model.addConstr(u[(i, j)] - z[j] <= 0, name=f'u2_{i}_{j}'))
                constrs.append(self.model.addConstr(- y[i] + z[j] - u[(i, j)] <= 0, name=f'u3_{i}_{j}'))

                if step != 0 and temp_counter % step == 0:
                    print(f"== progress:{temp_counter / pair_num * 100:.0f}%", file=f_print, flush=True)
        print_time(start, time.time(), "u constr created", file=f_print)

        """ v[i, j] = y[i] * (w[j] - z[j] * w[j]) = y[i] * w[j]
        v[i, j] = 1 only when y[i] == 1 (job in cloud) and w[j] == 1 (w represents the state of data in cloud, 1 means not in cloud)
        This should be ingress for input data (or say read), egress for output data (or say write)
        """
        temp_counter = 0
        start = time.time()
        print("v constr start", file=f_print, flush=True)
        for j in range(M):
            if j in replicated_indices:
                continue
            job_ids = self.adj_list_input[j].keys()
            for i in job_ids:
            # for i in range(N):
            #     if self.input_matrix_gb[i, j] > 0 or self.output_matrix_gb[i, j] > 0:
                temp_counter += 1
                constrs.append(self.model.addConstr(v[(i, j)] - y[i] <= 0, name=f'v1_{i}_{j}'))
                constrs.append(self.model.addConstr(v[(i, j)] - w[j] <= 0, name=f'v2_{i}_{j}'))
                constrs.append(self.model.addConstr(y[i] + w[j] - v[(i, j)] <= 1, name=f'v3_{i}_{j}'))
                if step != 0 and temp_counter % step == 0:
                    print(f"== progress:{temp_counter / pair_num * 100:.0f}%", file=f_print, flush=True)
        print_time(start, time.time(), "v constr created", file=f_print)
        return constrs

    def get_y_z_w_u_v(self):
        return self.y, self.z, self.w, self.u, self.v

    def update_previous_placement(self, file_path):
        self.previous_placement_path = file_path
        # header: table,z,w,size
        prev_placement_df = pd.read_csv(file_path)
        prev_placement_df['table'] = prev_placement_df['table'].astype(str)

        prev_placement_df, _ = merge_similar_rows(prev_placement_df, self.yugong)
        self.prev_z = len(self.unique_db_tables) * [-1]
        self.prev_w = len(self.unique_db_tables) * [-1]
        for idx, row in prev_placement_df.iterrows():
            db_table = row['table']
            j = self.unique_db_tables[db_table]
            assert j is not None and j < len(
                prev_placement_df), f"idx {idx}, {db_table}: j={j} >= {len(prev_placement_df)}"
            self.prev_z[j] = int(row['z'])
            self.prev_w[j] = int(row['w'])

        db_tables_list = list(self.unique_db_tables.keys())

        count = 0
        print("New tables", len(self.unique_db_tables) - len(self.prev_z))

        """
        if self.yugong:
            self.df_table_size['project'] = self.df_table_size.apply(
                lambda row: self.ownership.get_table_ownership(f"{row['hive_database_name']}.{row['hive_table_name']}"), axis=1)
            missing_sizes = self.df_table_size[self.df_table_size.apply(
                lambda row: f"{row['hive_database_name']}.{row['hive_table_name']}" in extra_db_tables, axis=1)]
            grouped_sizes = missing_sizes.groupby('project')['dir_size'].sum()
            group_num = len(grouped_sizes)
            print("# of grouped untouched projects this time period", group_num, flush=True, end=' ')
            for project in grouped_sizes.index:
                j_string = f"{project}.group"
                self.ownership.add_table_ownership(j_string, "Root|" + project)
                assert self.ownership.get_table_ownership(j_string) == project, f"Ownership not set for {j_string}"
        """

        for j in range(len(prev_placement_df), len(self.unique_db_tables)):
            key = db_tables_list[j]
            db_name, table_name = key.split('.')

            if self.yugong:
                project = self.ownership.get_table_ownership(key)
                db_group = f"{project}.group"
            else:
                db_group = f"{db_name}.group"

            if db_group in self.unique_db_tables:
                idx = self.unique_db_tables[db_group]
                self.prev_z[j] = self.prev_z[idx]
                self.prev_w[j] = self.prev_w[idx]
                count += 1
            else:
                #print(f"Warning: {db_group} not found")
                self.prev_z[j] = 0
                self.prev_w[j] = 0

        # debug
        for i in range(len(self.prev_z)):
            if self.prev_z[i] == -1:
                print("Warning: prev_z not updated", i, end=' ')
                for key in self.unique_db_tables:
                    if self.unique_db_tables[key] == i:
                        print(key)

        print("Updated previous placement", count, "times from grouped dbs")

    def solve_gurobi(self, p_egress_gb, p2_gb, r_min, r_max, X, dir_path,
                     s_min=0.0, s_max=1.0, binary=True,
                     time_limit=60 * 60, alpha=1,
                     p_network_gb=0
                     ):

        if not os.path.exists(dir_path):
            os.makedirs(dir_path)
        self.model.setParam("LogFile", dir_path + "/gurobi.log")

        f_print = open(os.path.join(dir_path, 'log.txt'), 'w')

        print("----------------------------------------", file=f_print, flush=True)
        print("Inputs: p1, p2, p3, c_min, c_max, X, s_min, s_max, binary", file=f_print, flush=True)
        print(p_egress_gb, p2_gb, p_network_gb, r_min, r_max, X, s_min, s_max, binary, file=f_print, flush=True)
        print(f"rep_threshold {self.rep_threshold} ({human_readable_size(self.rep_threshold * self.total_storage_gb * 1024**3)})", file=f_print, flush=True)
        print("previous placement", self.previous_placement_path, file=f_print, flush=True)
        print("workload info", self.workload_print_info, file=f_print, flush=True)
        print("db_table_size path", self.db_table_size_path, file=f_print, flush=True)
        print(f"k={self.k:.3f}", file=f_print, flush=True)
        print("X", int(X), "updated to", int(X * self.X_scale), f"({self.X_scale:.3f})", file=f_print, flush=True)
        X = X * self.X_scale
        print("YuGong", self.yugong, file=f_print, flush=True)
        print("----------------------------------------", file=f_print, flush=True)

        self.model.setParam(GRB.Param.TimeLimit, time_limit)

        # Define the decision variables
        N = self.abFP_num
        M = self.dataset_num
        print("N, M", N, M, file=f_print, flush=True)

        y, z, w, u, v = self.get_y_z_w_u_v()
        var_constr = []

        # Constraints for computation
        total_computation = sum(self.c)

        start = time.time()
        comp_expr1 = LinExpr()
        for i in range(N):
            comp_expr1.add(y[i], self.c.iloc[i])
        self.model.addConstr(-comp_expr1 <= -r_min * total_computation, name='comp1')
        self.model.addConstr(comp_expr1 <= r_max * total_computation, name='comp2')
        print_time(start, time.time(), "r constraints created", file=f_print)

        # Constraints for local storage
        z_var_constr = []
        if s_max < 1:
            # total_storage = sum(self.s)
            start = time.time()
            comp_expr3 = LinExpr()
            for j in range(M):
                comp_expr3.add(z[j], self.s[j])
            z_var_constr.append(
                self.model.addConstr(-comp_expr3 <= -(1 - s_max) * self.total_storage_gb, name='local1'))
            print_time(start, time.time(), f"local storage constraint created: <="
                                           f" {human_readable_size(s_max * self.total_storage_gb * 1024 ** 3)}"
                                           f" ({s_max:.2f} * {human_readable_size(self.total_storage_gb * 1024 ** 3)})",
                       file=f_print)
        else:
            for j in range(M):
                z_var_constr.append(
                    self.model.addConstr(z[j] == 0, name=f'z_{j}'))  # z[j] = 0: keep everything in D1 since local storage is unlimited
            print("s_max >= 1, no local storage constraint", file=f_print, flush=True)
        if s_min > 0.0:
            start = time.time()
            comp_expr4 = LinExpr()
            for j in range(M):
                comp_expr4.add(z[j], self.s[j])
            self.model.addConstr(comp_expr4 <= (1 - s_min) * self.total_storage_gb, name='local2')
            print_time(start, time.time(), f"local storage constraint created: >="
                                           f" {human_readable_size(s_min * self.total_storage_gb * 1024 ** 3)}"
                                           f" ({s_min:.2f} * {human_readable_size(self.total_storage_gb * 1024 ** 3)})",
                       file=f_print)

        start = time.time()

        # This is network usage constraint, so include every piece of ingress and egress
        # that comes from query execution
        # data movement suggested by Moirai is not included here
        comp_expr2 = LinExpr()

        replicated_indices = [self.unique_db_tables[key] for key in self.rep_list if key in self.unique_db_tables]

        for j in range(M):
            if j in replicated_indices:
                continue
            job_input = self.adj_list_input[j]
            job_output = self.adj_list_output[j]
            for i, input_size in job_input.items():
                output_size = job_output[i]  # input/output adj lists share keys
                if input_size > 0:
                    comp_expr2.add(v[(i, j)], input_size) # ingress
                    comp_expr2.add(u[(i, j)], input_size) # egress
                if output_size > 0:
                    comp_expr2.add(v[(i, j)], output_size) # egress
                    comp_expr2.add(u[(i, j)], output_size) # ingress

        self.model.addConstr(comp_expr2 <= X, name='comp3')  # bandwidth
        print_time(start, time.time(), "X constraint created", file=f_print)

        # Set objective function
        start = time.time()
        obj_expr = LinExpr()
        for j in range(M):
            obj_expr.add(p2_gb * self.s[j] * (1 - w[j]))  # cloud object storage cost
        if self.prev_z is not None:  # data movement
            assert self.prev_w is not None
            if alpha > 1:
                print("alpha > 1, data movement is disallowed", file=f_print, flush=True)

                # existing cloud only dataset should not be removed
                var_constr.append(
                    self.model.addConstr(sum(w[j] if self.prev_w[j] == 0 else 0 for j in range(M)) == 0, name='w_sum'))
                # or copied back to onprem
                var_constr.append(
                    self.model.addConstr(sum(1 - z[j]
                                             if self.prev_w[j] == 0 and self.prev_z[j] == 1 else 0
                                             for j in range(M)) == 0, name='z_sum'))
                # TODO: this might incur extra replication from onprem to cloud, or more remote data access
            else:
                print("alpha <= 1, data movement is allowed, but penalized", file=f_print, flush=True)
                print("TODO: fake ingress price", p_egress_gb, file=f_print, flush=True)
                for j in range(len(self.prev_z)):
                    if self.prev_z[j] > 0:
                        obj_expr.add(p_egress_gb * self.s[j] * (self.prev_z[j] - z[j]) * alpha)  # egress cost
                    if self.prev_w[j] > 0:
                        obj_expr.add(p_egress_gb * self.s[j] * (self.prev_w[j] - w[j]) * alpha)  # ingress cost (fake)
                """ Deprecated, but might be useful in the future"""
                # print("Ingress price", p_network_gb, file=f_print, flush=True)
                # for j in range(len(self.prev_z)):
                #     if self.prev_z[j] > 0:
                #         obj_expr.add((p_egress_gb+p_network_gb) * self.s[j] * (self.prev_z[j] - z[j]) * alpha)  # egress cost
                #     if self.prev_w[j] > 0:
                #         obj_expr.add(p_network_gb * self.s[j] * (self.prev_w[j] - w[j]) * alpha)  # ingress cost (fake)

        for j in range(M):
            if j in replicated_indices:
                continue
            job_input = self.adj_list_input[j]
            job_output = self.adj_list_output[j]
            for i, input_size in job_input.items():
                output_size = job_output[i]  # input/output adj lists share keys
                if input_size > 0:
                    obj_expr.add(u[(i, j)], (p_egress_gb+p_network_gb) * input_size) # egress
                    obj_expr.add(v[(i, j)], p_network_gb * input_size) # ingress
                if output_size > 0:
                    obj_expr.add(u[(i, j)], p_network_gb * output_size) # ingress
                    obj_expr.add(v[(i, j)], (p_egress_gb+p_network_gb) * output_size) # egress
            # for i in range(N):
            #     if self.input_matrix_gb[i, j] > 0:
            #         obj_expr.add(u[(i, j)], (p_egress_gb+p_network_gb) * self.input_matrix_gb[i, j])
            #         obj_expr.add(v[(i, j)], p_network_gb * self.input_matrix_gb[i, j])
            #     if self.output_matrix_gb[i, j] > 0:
            #         obj_expr.add(u[(i, j)], p_network_gb * self.output_matrix_gb[i, j])
            #         obj_expr.add(v[(i, j)], (p_egress_gb+p_network_gb) * self.output_matrix_gb[i, j])

        self.model.setObjective(obj_expr, GRB.MINIMIZE)
        print_time(start, time.time(), "obj created", file=f_print)

        self.model.update()

        print('----------------------------------------', file=f_print, flush=True)
        # Optimize
        start = time.time()
        self.model.optimize()
        print_time(start, time.time(), "model solved", file=f_print)

        # Print solution
        print("model status", self.model.status, file=f_print, flush=True)
        if self.model.status in (GRB.OPTIMAL, GRB.TIME_LIMIT, GRB.INTERRUPTED):
            if self.model.status == GRB.OPTIMAL:
                print('Optimal solution found', file=f_print, flush=True)
            else:
                print('Suboptimal solution found', file=f_print, flush=True)
            print('Optimal cost: %g' % self.model.objVal, end=' ', file=f_print, flush=True)
            print(f'({self.model.Runtime:.2f} seconds', end=' ', file=f_print, flush=True)
            # Print number of iterations
            iterations = self.model.IterCount
            print(f"in {iterations} iterations)", file=f_print, flush=True)

            # Print number of constraints
            num_constraints = len(self.model.getConstrs())
            print(f"Under # constraints: {num_constraints}", file=f_print, flush=True)
            print(f"Computation: {sum(y[i].x * self.c.iloc[i] for i in range(N)):.0f}", end=' ', file=f_print,
                  flush=True)
            # print the allowed computation range derived from total_computation
            min_target = r_min * total_computation
            max_target = r_max * total_computation
            print(f"∈ [{min_target:.0f} ({r_min}), "
                  f"{max_target:.0f} ({r_max})]", file=f_print, flush=True)

            local_storage = sum((1 - z[j].x) * self.s[j] for j in range(M))
            remote_storage = sum((1 - w[j].x) * self.s[j] for j in range(M))
            replication = sum((1 - z[j].x - w[j].x) * self.s[j] for j in range(M))
            total = local_storage + remote_storage - replication
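            # Example (hedged): a dataset of size s placed on both sites has
            # z = w = 0, so it contributes s to local_storage, s to
            # remote_storage, and s to replication; subtracting the overlap
            # keeps `total` equal to the unique data footprint.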
            print(f"Storage: {human_readable_size(total * 1024 ** 3)} total "
                  # f"== {human_readable_size(self.total_storage_gb * 1024 ** 3)} total from the dataset file, "
                  f"{human_readable_size(local_storage * 1024 ** 3)} on-prem "
                  f" ∈ [{human_readable_size(s_min * total * 1024 ** 3)} ({s_min}), "
                  f"{human_readable_size(s_max * total * 1024 ** 3)} ({s_max})], "
                  f"{human_readable_size(remote_storage * 1024 ** 3)} GCP "
                  f"($ {remote_storage * p2_gb:.0f} cost), "
                  f"{human_readable_size(replication * 1024 ** 3)} overlap", file=f_print, flush=True)

            ingress_gb, egress_gb = 0, 0
            for j in range(M):
                if j in replicated_indices:
                    continue
                job_input = self.adj_list_input[j]
                job_output = self.adj_list_output[j]
                for i, input_size in job_input.items():
                    output_size = job_output[i]  # input/output adj lists share keys
                    if v[(i, j)].x > 0:
                        ingress_gb += input_size
                        egress_gb += output_size

            for j in range(M):
                if j in replicated_indices:
                    continue
                job_input = self.adj_list_input[j]
                job_output = self.adj_list_output[j]
                for i, input_size in job_input.items():
                    output_size = job_output[i]  # input/output adj lists share keys
                    if u[(i, j)].x > 0:
                        egress_gb += input_size
                        ingress_gb += output_size
            print(f"Ingress {human_readable_size(ingress_gb * 1024 ** 3)} "
                  f"< {human_readable_size(X * 1024 ** 3)}", file=f_print, flush=True)
            print(f'Egress {human_readable_size(egress_gb * 1024 ** 3)} '
                  f'{egress_gb * p_egress_gb:.2f} $', file=f_print, flush=True)
            print(f"Network {human_readable_size((ingress_gb + egress_gb) * 1024 ** 3)} "
                  f"{(ingress_gb + egress_gb) * p_network_gb:.2f} $ (Estimated)", file=f_print, flush=True)

            # consider data movement
            if self.prev_z is not None:
                assert self.prev_w is not None

                egress_movement = 0
                for j in range(len(self.prev_z)):
                    if self.prev_z[j] > 0:
                        egress_movement += (self.prev_z[j] - z[j].x) * self.s[j]
                        # debug
                        if self.prev_z[j] - z[j].x < 0:
                            print(f"Warning: egress {j} {self.prev_z[j]} - {z[j].x} * {self.s[j]} = "
                                  f"{human_readable_size(self.prev_z[j] - z[j].x) * self.s[j] * 1024 ** 3}",
                                  file=f_print, flush=True)

                # egress_movement = sum((self.prev_z[j] - z[j].x) * self.s[j]
                #                        if self.prev_z[j] > 0 else 0
                #                        for j in range(len(self.prev_z)))
                ingress_movement = sum((self.prev_w[j] - w[j].x) * self.s[j]
                                       if self.prev_w[j] > 0 else 0
                                       for j in range(len(self.prev_w)))

                # debug, print the ones that are moved
                # for table in self.unique_db_tables:
                #     j = self.unique_db_tables[table]
                #     if self.prev_w[j] > 0 and w[j].x == 0:
                #         print(f"Warning: ingress {table} {j} {self.prev_w[j]} - {w[j].x} * {self.s[j]} = "
                #               f"{human_readable_size((self.prev_w[j] - w[j].x) * self.s[j] * 1024 ** 3)}",
                #               file=f_print, flush=True)

                print(f"Data movement: {human_readable_size(ingress_movement * 1024 ** 3)} ingress, "
                      f"{human_readable_size(egress_movement * 1024 ** 3)} "
                      f"{egress_movement * p_egress_gb:.2f}$ egress, "
                      f"{human_readable_size((ingress_movement + egress_movement) * 1024 ** 3)} network "
                      f"{(ingress_movement + egress_movement) * p_network_gb:.2f}$ (Estimated)", file=f_print, flush=True)

            # dump z, w, y to file
            with open(dir_path + '/dataset_placement.csv', 'w') as f:
                # header
                f.write('table,z,w,size\n')
                for key in self.unique_db_tables:
                    j = self.unique_db_tables[key]
                    f.write(f'{key},{z[j].x},{w[j].x},{self.s[j]}\n')
            with open(dir_path + '/query_placement.csv', 'w') as f:
                # header
                f.write('abFP,y\n')
                for key in self.unique_abFP:
                    f.write(f'{key},{y[self.unique_abFP[key]].x}\n')

        f_print.close()

        # clean up constraints
        print("before clean up", len(self.model.getConstrs()))
        self.model.remove(var_constr)
        self.model.remove(self.model.getConstrByName('comp1'))
        self.model.remove(self.model.getConstrByName('comp2'))
        self.model.remove(self.model.getConstrByName('comp3'))
        self.model.remove(z_var_constr)
        if self.model.getConstrByName('local2') is not None:
            self.model.remove(self.model.getConstrByName('local2'))
        self.model.update()
        print("after clean up", len(self.model.getConstrs()))
        print("\n")
