def _get_sampled_indexes_and_baseline_df()

in misc/CCSynth/CC/DataInsights/src/prose/datainsights/_assertion/_assertion_helper.py [0:0]


    def _get_sampled_indexes_and_baseline_df(
        self, violation_threshold, sample_only=True, random_state=None
    ):
        """Pick violating test rows per constrained invariant and build baseline rows.

        Invariants in ``self.assertions.constrained_invariants`` are visited in
        reverse order. For each one, the test rows selected by its constraint are
        looked up in ``self.row_wise_violation_summary`` (column ``"violation"``)
        and ordered by descending violation.

        * ``sample_only=True``: among the rows tied at the partition's maximum
          violation, the first one (in shuffled order) not already picked for an
          earlier invariant is selected, provided that maximum is at least
          ``violation_threshold``. Partitions yielding no such row are skipped.
        * ``sample_only=False``: every row of the partition not already picked is
          selected, in descending order of violation. ``violation_threshold`` is
          not consulted in this mode.

        For every selected row, a baseline row is derived from the training rows
        matching the same constraint. Numeric columns get the partition mean;
        the remaining columns get the partition's first mode.

        Args:
            violation_threshold: Minimum violation a sampled row must reach
                (only used when ``sample_only`` is true).
            sample_only: Pick one representative row per partition instead of
                all of them.
            random_state: Forwarded to ``DataFrame.sample`` so the choice among
                equally-violating rows can be made reproducible. ``None`` keeps
                the unseeded behaviour.

        Returns:
            ``(sampled_indexes, baseline_df)``. ``sampled_indexes`` is a list of
            test-row index labels. ``baseline_df`` has the columns of
            ``self.train_df`` and is row-aligned with ``sampled_indexes``, or is
            ``None`` if no partition contributed rows.
        """
        sampled_indexes = []
        seen = set()  # mirrors sampled_indexes for O(1) membership tests
        baseline_parts = []
        columns = self.train_df.columns

        for k in reversed(self.assertions.constrained_invariants):
            partition_index = list(k.constraint.apply(self.test_df).index)
            cur_df = pd.DataFrame(
                self.row_wise_violation_summary.loc[partition_index]["violation"]
            ).sort_values(by=["violation"], ascending=False)
            if cur_df.empty:
                continue

            if sample_only:
                # Pick a random representative among the partition's worst rows.
                top = cur_df["violation"].max()
                candidates = (
                    cur_df[cur_df["violation"] == top]
                    .sample(frac=1, random_state=random_state)
                    .index
                )
                # All candidates share the same violation, so one check suffices.
                if top < violation_threshold:
                    continue
                new_indexes = []
                for idx in candidates:
                    if idx not in seen:
                        new_indexes.append(idx)
                        break
                if not new_indexes:
                    continue
            else:
                # Ordered de-duplication: keep first occurrences, drop rows
                # already picked for an earlier invariant.
                new_indexes = [
                    idx for idx in dict.fromkeys(cur_df.index) if idx not in seen
                ]

            sampled_indexes.extend(new_indexes)
            seen.update(new_indexes)

            numeric_part = k.constraint.apply(self.train_df)._get_numeric_data()
            numeric_cols = numeric_part.columns
            categorical_part = k.constraint.apply(self.train_df, drop_column=False)[
                [col for col in columns if col not in numeric_cols]
            ]

            # Baseline used to visually contrast the representative violating
            # rows. Numeric columns take the partition mean, as a one-row frame.
            current_mean_row = pd.DataFrame(
                np.array(numeric_part.mean()).reshape((1, len(numeric_cols))),
                columns=numeric_cols,
            )
            # Other columns take the first mode. This is only needed (and only
            # computed) when such columns exist.
            current_mode_row = None
            if len(categorical_part.columns) > 0:
                current_mode_row = pd.DataFrame(
                    np.array(categorical_part.mode()).reshape(
                        (-1, len(categorical_part.columns))
                    ),
                    columns=categorical_part.columns,
                )[:1]

            cur_row = pd.DataFrame(columns=columns)
            for col in columns:
                if col in numeric_cols:
                    cur_row[col] = current_mean_row[col]
                else:
                    cur_row[col] = current_mode_row[col]

            if not sample_only:
                # Repeat the baseline once per newly picked row so both outputs
                # stay aligned; np.tile yields object dtype, so restore floats.
                cur_row = pd.DataFrame(
                    np.tile(np.array(cur_row), (len(new_indexes), 1)),
                    columns=columns,
                )
                for col in columns:
                    if col in numeric_cols:
                        cur_row[col] = cur_row[col].apply(float)

            baseline_parts.append(cur_row)

        if not baseline_parts:
            baseline_df = None
        elif len(baseline_parts) == 1:
            baseline_df = baseline_parts[0]
        else:
            # One concat at the end instead of re-copying the frame each iteration.
            baseline_df = pd.concat(baseline_parts, ignore_index=True)

        return sampled_indexes, baseline_df