def prepare_workload()

in optimizer.py [0:0]


    def prepare_workload(self):
        """Index the current workload and build its adjacency lists and size vector.

        Reads ``self.df`` (columns ``abstractFingerPrint``, ``db_name``,
        ``table_name``, ``inputDataSize``, ``outputDataSize``) and
        ``self.df_table_size`` (columns ``hive_database_name``,
        ``hive_table_name``, ``dir_size``), then updates:

        * ``self.unique_abFP`` / ``self.unique_db_tables``: ids already present
          are kept; every unseen key receives the next free integer id.
        * ``self.abFP_num`` / ``self.db_table_num``: sizes of those two maps
          right after the workload rows have been registered.
        * ``self.adj_list_input`` / ``self.adj_list_output``: rebuilt from
          scratch as ``adj[table_id][abFP_id] = size / 1024 ** 3``.
        * One ``"<label>.group"`` dataset per group of tables that appear in
          ``self.df_table_size`` but are not keys of ``self.unique_db_tables``.
          Tables are grouped by project (looked up through ``self.ownership``)
          when ``self.yugong`` is truthy, otherwise by ``hive_database_name``.
        * ``self.dataset_num``: total number of dataset ids.
        * ``self.s``: array of length ``self.dataset_num`` holding ``dir_size``
          per dataset id (0 when no size is known).

        Side effects when ``self.yugong`` is truthy: a ``project`` column is
        written to ``self.df_table_size`` and every ``"<project>.group"`` key is
        registered with ``self.ownership``.

        Workload rows whose table has no known size are kept and get size 0;
        the former ``clean_df()`` filter that dropped them is deprecated and
        has been removed.

        Raises:
            AssertionError: if ``self.df_table_size`` is None or a group
                ownership could not be read back after being registered.
            ValueError: if a key of ``self.unique_db_tables`` does not split
                into exactly two dot-separated parts.
        """

        def prepare_unique_abFP():
            # Register every fingerprint / table of the workload and rebuild the
            # adjacency lists in a single pass over the needed columns.
            print(f"from {len(self.unique_abFP)} x {len(self.unique_db_tables)}", end="")

            self.adj_list_input = defaultdict(dict)
            self.adj_list_output = defaultdict(dict)

            workload_rows = zip(
                self.df['abstractFingerPrint'],
                self.df['db_name'],
                self.df['table_name'],
                self.df['inputDataSize'],
                self.df['outputDataSize'],
            )
            for i_string, db_name, table_name, input_size, output_size in workload_rows:
                j_string = f"{db_name}.{table_name}"
                # Ids are dense (0..n-1), so the current length is the next free id.
                if i_string not in self.unique_abFP:
                    self.unique_abFP[i_string] = len(self.unique_abFP)
                if j_string not in self.unique_db_tables:
                    self.unique_db_tables[j_string] = len(self.unique_db_tables)
                i = self.unique_abFP[i_string]
                j = self.unique_db_tables[j_string]
                # A repeated (table, fingerprint) pair keeps the values of its last row.
                self.adj_list_input[j][i] = input_size / 1024 ** 3  # convert to GB
                self.adj_list_output[j][i] = output_size / 1024 ** 3

            self.abFP_num = len(self.unique_abFP)
            self.db_table_num = len(self.unique_db_tables)
            print(f" to {self.abFP_num} x {self.db_table_num}")

        prepare_unique_abFP()

        pair_num = len(self.df)
        print("[sanity check] # of non-zero edges [i,j]", pair_num, flush=True)

        # expand s variable to all db_table
        assert self.df_table_size is not None

        # "db.table" key of every row of the size table, computed once.
        # zip is used instead of DataFrame.apply(axis=1): it is much cheaper and
        # also behaves on an empty frame, where apply returns a DataFrame and
        # set() of it would yield column names instead of keys.
        size_keys = [
            f"{db}.{table}"
            for db, table in zip(self.df_table_size['hive_database_name'],
                                 self.df_table_size['hive_table_name'])
        ]

        # Tables with a known size that are not registered in self.unique_db_tables.
        extra_db_tables = set(size_keys) - set(self.unique_db_tables.keys())
        is_extra = np.fromiter(
            (key in extra_db_tables for key in size_keys), dtype=bool, count=len(size_keys))

        if self.yugong:
            self.df_table_size['project'] = [
                self.ownership.get_table_ownership(key) for key in size_keys]
            missing_sizes = self.df_table_size[is_extra]
            grouped_sizes = missing_sizes.groupby('project')['dir_size'].sum()
            group_num = len(grouped_sizes)
            print("# of grouped untouched projects this time period", group_num, flush=True, end=' ')
            for project in grouped_sizes.index:
                j_string = f"{project}.group"
                self.ownership.add_table_ownership(j_string, project)
                assert self.ownership.get_table_ownership(j_string) == project, f"Ownership not set for {j_string}"
        else:
            # Keep only the extra db_tables, then sum their sizes per database.
            missing_sizes = self.df_table_size[is_extra]
            grouped_sizes = missing_sizes.groupby('hive_database_name')['dir_size'].sum()
            group_num = len(grouped_sizes)
            print("# of grouped untouched dbs this time period", group_num, flush=True, end=' ')

        count_bf = len(self.unique_db_tables)
        # allocate id for untouched groups (a group seen in an earlier call keeps its id)
        for group_label in grouped_sizes.index:
            j_string = f"{group_label}.group"
            if j_string not in self.unique_db_tables:
                self.unique_db_tables[j_string] = len(self.unique_db_tables)
        self.dataset_num = len(self.unique_db_tables)
        print("(only", self.dataset_num - count_bf, "newly appeared dbs)", flush=True)

        # Newly allocated group ids have no edges; give them explicit empty entries.
        for table_id in range(self.db_table_num, self.dataset_num):
            self.adj_list_input[table_id] = {}
            self.adj_list_output[table_id] = {}
        print(f"adjacency lists have tables {len(self.adj_list_input)}", flush=True)

        self.s = np.zeros(self.dataset_num)
        start = time.time()

        # (db, table) -> dir_size; a repeated pair keeps the size of its last row.
        size_lookup = dict(zip(
            zip(self.df_table_size['hive_database_name'], self.df_table_size['hive_table_name']),
            self.df_table_size['dir_size'],
        ))

        sum_hot_gb = 0
        # Size assigned to a dataset that has no (or a zero) entry in the size table.
        set_size = 0
        for db_table, dataset_id in self.unique_db_tables.items():
            parts = db_table.split('.')
            if len(parts) != 2:
                raise ValueError(f"Invalid db_table format: {db_table} into {parts}")
            db_name, table_name = parts
            size = size_lookup.get((db_name, table_name), None)
            if size is None or size == 0:
                if table_name == 'group' and db_name in grouped_sizes.index:
                    continue  # "<label>.group" entry: filled from grouped_sizes below
                self.s[dataset_id] = set_size
                sum_hot_gb += set_size
            else:
                self.s[dataset_id] = size
                sum_hot_gb += size
        # NOTE(review): the * 1024 ** 3 implies dir_size is expressed in GB — confirm.
        print(f"Touched data size: {human_readable_size(sum_hot_gb * 1024 ** 3)}")

        # cold dataset that is not in the workload
        sum_gb = 0
        # items() pairs each label with its own sum, so the key built here is the
        # same one registered above even when a label is not a str.
        for group_label, group_size in grouped_sizes.items():
            self.s[self.unique_db_tables[f"{group_label}.group"]] = group_size
            sum_gb += group_size
        print(f"Non-touched data size: {human_readable_size(sum_gb * 1024 ** 3)}")
        print_time(start, time.time(), "s created")