def generate_feature_importance_data()

in scripts/commit_classifier.py


    def generate_feature_importance_data(self, probs, importance):
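        """Build the data used to explain the model's prediction.

        For each important feature of the predicted class, compute SHAP-based
        statistics and a summary plot, group closely related features together,
        and dump the result to importances.json.
        """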
        _X = get_transformer_pipeline(self.clf).transform(self.X)
        X_shap_values = shap.TreeExplainer(
            self.clf.named_steps["estimator"]
        ).shap_values(_X)

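        # Map the index of the highest-probability class back to its label.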
        pred_class = self.model.le.inverse_transform([probs[0].argmax()])[0]

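        # Collect per-feature explanation data for the predicted class.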
        features = []
        for i, (val, feature_index, is_positive) in enumerate(
            importance["importances"]["classes"][pred_class][0]
        ):
            name = importance["feature_legend"][str(i + 1)]
            value = importance["importances"]["values"][0, int(feature_index)]

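            # Render a SHAP summary plot for this feature and keep it as a base64-encoded image.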
            shap.summary_plot(
                X_shap_values[:, int(feature_index)].reshape(_X.shape[0], 1),
                _X[:, int(feature_index)].reshape(_X.shape[0], 1),
                feature_names=[""],
                plot_type="layered_violin",
                show=False,
            )
            matplotlib.pyplot.xlabel("Impact on model output")
            img = io.BytesIO()
            matplotlib.pyplot.savefig(img, bbox_inches="tight")
            matplotlib.pyplot.clf()
            img.seek(0)
            base64_img = base64.b64encode(img.read()).decode("ascii")

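            # Compute the Spearman correlation between the feature and the label, ignoring samples where the feature is zero.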
            X = _X[:, int(feature_index)]
            y = self.y[X != 0]
            X = X[X != 0]
            spearman = spearmanr(X, y)

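            # Split the feature values by label and compute their medians.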
            buggy_X = X[y == 1]
            clean_X = X[y == 0]
            median = np.median(X)
            median_clean = np.median(clean_X)
            median_buggy = np.median(buggy_X)

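            # Fraction of buggy/clean values above or below the overall median.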
            perc_buggy_values_higher_than_median = (
                buggy_X >= median
            ).sum() / buggy_X.shape[0]
            perc_buggy_values_lower_than_median = (
                buggy_X < median
            ).sum() / buggy_X.shape[0]
            perc_clean_values_higher_than_median = (
                clean_X > median
            ).sum() / clean_X.shape[0]
            perc_clean_values_lower_than_median = (
                clean_X <= median
            ).sum() / clean_X.shape[0]

            logger.info("Feature: {}".format(name))
            logger.info("Shap value: {}{}".format("+" if (is_positive) else "-", val))
            logger.info("spearman: %f", spearman)
            logger.info("value: %f", value)
            logger.info("overall mean: %f", np.mean(X))
            logger.info("overall median: %f", np.median(X))
            logger.info("mean for y == 0: %f", np.mean(clean_X))
            logger.info("mean for y == 1: %f", np.mean(buggy_X))
            logger.info("median for y == 0: %f", np.median(clean_X))
            logger.info("median for y == 1: %f", np.median(buggy_X))

            logger.info(
                "perc_buggy_values_higher_than_median: %f",
                perc_buggy_values_higher_than_median,
            )
            logger.info(
                "perc_buggy_values_lower_than_median: %f",
                perc_buggy_values_lower_than_median,
            )
            logger.info(
                "perc_clean_values_higher_than_median: %f",
                perc_clean_values_higher_than_median,
            )
            logger.info(
                "perc_clean_values_lower_than_median: %f",
                perc_clean_values_lower_than_median,
            )

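            # Store everything needed to describe this feature in the output.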
            features.append(
                {
                    "index": i + 1,
                    "name": name,
                    "shap": float(f"{'+' if (is_positive) else '-'}{val}"),
                    "value": importance["importances"]["values"][0, int(feature_index)],
                    "spearman": spearman,
                    "median": median,
                    "median_bug_introducing": median_buggy,
                    "median_clean": median_clean,
                    "perc_buggy_values_higher_than_median": perc_buggy_values_higher_than_median,
                    "perc_buggy_values_lower_than_median": perc_buggy_values_lower_than_median,
                    "perc_clean_values_higher_than_median": perc_clean_values_higher_than_median,
                    "perc_clean_values_lower_than_median": perc_clean_values_lower_than_median,
                    "plot": base64_img,
                }
            )

        # Group together features that are very similar to each other, so we can simplify the explanation
        # to users.
        attributes = ["Total", "Maximum", "Minimum", "Average"]
        already_added = set()
        feature_groups = []
        for i1, f1 in enumerate(features):
            if i1 in already_added:
                continue

            feature_groups.append([f1])

            for j, f2 in enumerate(features[i1 + 1 :]):
                i2 = j + i1 + 1

                f1_name = f1["name"]
                for attribute in attributes:
                    if f1_name.startswith(attribute):
                        f1_name = f1_name[len(attribute) + 1 :]
                        break

                f2_name = f2["name"]
                for attribute in attributes:
                    if f2_name.startswith(attribute):
                        f2_name = f2_name[len(attribute) + 1 :]
                        break

                if f1_name != f2_name:
                    continue

                already_added.add(i2)
                feature_groups[-1].append(f2)

        # Pick a representative example from each group.
        features = []
        for feature_group in feature_groups:
            shap_sum = sum(f["shap"] for f in feature_group)

            # Only select easily explainable features from the group.
            selected = [
                f
                for f in feature_group
                if (
                    f["shap"] > 0
                    and abs(f["value"] - f["median_bug_introducing"])
                    < abs(f["value"] - f["median_clean"])
                )
                or (
                    f["shap"] < 0
                    and abs(f["value"] - f["median_clean"])
                    < abs(f["value"] - f["median_bug_introducing"])
                )
            ]

            # If there are no easily explainable features in the group, select all features of the group.
            if len(selected) == 0:
                selected = feature_group

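            # Rank features by how strongly the data supports the direction of their SHAP value, given the sign of their correlation with the label.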
            def feature_sort_key(f):
                if f["shap"] > 0 and f["spearman"][0] > 0:
                    return f["perc_buggy_values_higher_than_median"]
                elif f["shap"] > 0 and f["spearman"][0] < 0:
                    return f["perc_buggy_values_lower_than_median"]
                elif f["shap"] < 0 and f["spearman"][0] > 0:
                    return f["perc_clean_values_lower_than_median"]
                elif f["shap"] < 0 and f["spearman"][0] < 0:
                    return f["perc_clean_values_higher_than_median"]

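            # Use the best-supported feature as the group's representative, but report the group's total SHAP contribution.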
            feature = max(selected, key=feature_sort_key)
            feature["shap"] = shap_sum

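            # Drop the aggregation prefix (e.g. "Total", "Average") from the representative feature's name.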
            for attribute in attributes:
                if feature["name"].startswith(attribute):
                    feature["name"] = feature["name"][len(attribute) + 1 :].capitalize()
                    break

            features.append(feature)

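        # Dump the selected features to a JSON file.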
        with open("importances.json", "w") as f:
            json.dump(features, f)