def train()

in bugbug/model.py [0:0]


    def train(self, importance_cutoff=0.15, limit=None):
        """Train the classifier, report evaluation metrics and save the model.

        The labelled items are split into a training and a test set. The
        classifier is then:

        - optionally cross-validated on the training set;
        - fitted;
        - optionally explained with SHAP;
        - evaluated on both sets, including at several confidence thresholds.

        Finally the model is optionally retrained on the entire dataset and
        pickled to ``<lowercased class name>/model.pkl``.

        Args:
            importance_cutoff: Cutoff forwarded to ``get_important_features``
                when feature importances are calculated.
            limit: If truthy, only the first ``limit`` items (and labels) are
                used for training and evaluation.

        Returns:
            A dict of tracking metrics. It may contain:

            - ``test_<scoring>`` entries (cross-validation mean and 2*std);
            - ``feature_report`` (when importances are computed);
            - ``report`` (single-label test-set classification report).

            It always contains ``confusion_matrix`` (test set, no confidence
            threshold).
        """
        classes, self.class_names = self.get_labels()
        self.class_names = sort_class_names(self.class_names)

        # Get items and labels, filtering out those for which we have no labels.
        X_gen, y = split_tuple_generator(lambda: self.items_gen(classes))

        # Extract features from the items.
        X = self.extraction_pipeline.transform(X_gen)

        # Calculate labels. The encoder is fitted before `limit` is applied,
        # so it sees every label in the dataset.
        y = np.array(y)
        self.le.fit(y)

        if limit:
            X = X[:limit]
            y = y[:limit]

        logger.info(f"X: {X.shape}, y: {y.shape}")

        is_multilabel = isinstance(y[0], np.ndarray)
        is_binary = len(self.class_names) == 2

        # Split dataset in training and test.
        X_train, X_test, y_train, y_test = self.train_test_split(X, y)

        tracking_metrics = {}

        # Use k-fold cross validation to evaluate results.
        if self.cross_validation_enabled:
            scorings = ["accuracy"]
            if is_binary:
                scorings += ["precision", "recall"]

            scores = cross_validate(
                self.clf, X_train, self.le.transform(y_train), scoring=scorings, cv=5
            )

            logger.info("Cross Validation scores:")
            for scoring in scorings:
                score = scores[f"test_{scoring}"]
                tracking_metrics[f"test_{scoring}"] = {
                    "mean": score.mean(),
                    "std": score.std() * 2,
                }
                logger.info(
                    f"{scoring.capitalize()}: {score.mean()} (+/- {score.std() * 2})"
                )

        logger.info(f"X_train: {X_train.shape}, y_train: {y_train.shape}")
        logger.info(f"X_test: {X_test.shape}, y_test: {y_test.shape}")

        self.clf.fit(X_train, self.le.transform(y_train))
        logger.info("Number of features: %d", self.clf.steps[-1][1].n_features_in_)

        logger.info("Model trained")

        feature_names = self.get_human_readable_feature_names()
        if self.calculate_importance and len(feature_names):
            explainer = shap.TreeExplainer(self.clf.named_steps["estimator"])
            _X_train = get_transformer_pipeline(self.clf).transform(X_train)
            shap_values = explainer.shap_values(_X_train)

            # In the binary case, sometimes shap returns a single shap values matrix.
            if is_binary and not isinstance(shap_values, list):
                shap_values = [-shap_values, shap_values]
                summary_plot_value = shap_values[1]
                summary_plot_type = "layered_violin"
            else:
                summary_plot_value = shap_values
                summary_plot_type = None

            shap.summary_plot(
                summary_plot_value,
                to_array(_X_train),
                feature_names=feature_names,
                class_names=self.class_names,
                plot_type=summary_plot_type,
                show=False,
            )

            # Set the label before saving, otherwise it is missing from the image.
            matplotlib.pyplot.xlabel("Impact on model output")
            matplotlib.pyplot.savefig("feature_importance.png", bbox_inches="tight")
            matplotlib.pyplot.clf()

            important_features = self.get_important_features(
                importance_cutoff, shap_values
            )

            self.print_feature_importances(important_features)

            # Save the important features in the metric report too
            feature_report = self.save_feature_importances(
                important_features, feature_names
            )

            tracking_metrics["feature_report"] = feature_report

        logger.info("Training Set scores:")
        y_pred = self.clf.predict(X_train)
        y_pred = self.le.inverse_transform(y_pred)
        if not is_multilabel:
            print(
                classification_report_imbalanced(
                    y_train, y_pred, labels=self.class_names
                )
            )

        logger.info("Test Set scores:")
        # Evaluate results on the test set.
        y_pred = self.clf.predict(X_test)
        y_pred = self.le.inverse_transform(y_pred)

        if is_multilabel:
            assert isinstance(y_pred[0], np.ndarray), (
                "The predictions should be multilabel"
            )

        logger.info(f"No confidence threshold - {len(y_test)} classified")
        if is_multilabel:
            confusion_matrix = metrics.multilabel_confusion_matrix(y_test, y_pred)
        else:
            confusion_matrix = metrics.confusion_matrix(
                y_test, y_pred, labels=self.class_names
            )

            print(
                classification_report_imbalanced(
                    y_test, y_pred, labels=self.class_names
                )
            )
            report = classification_report_imbalanced_values(
                y_test, y_pred, labels=self.class_names
            )

            tracking_metrics["report"] = report

        print_labeled_confusion_matrix(
            confusion_matrix, self.class_names, is_multilabel=is_multilabel
        )

        tracking_metrics["confusion_matrix"] = confusion_matrix.tolist()

        confidence_thresholds = [0.6, 0.7, 0.8, 0.9]

        if is_binary:
            confidence_thresholds = [0.1, 0.2, 0.3, 0.4] + confidence_thresholds

        # Neither the predicted probabilities nor the extended label set depend
        # on the threshold, so compute them once instead of per iteration.
        y_pred_probas = self.clf.predict_proba(X_test)
        confidence_class_names = self.class_names + ["__NOT_CLASSIFIED__"]

        # Evaluate results on the test set for some confidence thresholds.
        for confidence_threshold in confidence_thresholds:
            y_pred_filter = []
            classified_indices = []
            for i in range(len(y_test)):
                probas = y_pred_probas[i]
                if is_binary:
                    argmax = 1 if probas[1] > confidence_threshold else 0
                else:
                    argmax = np.argmax(probas)

                # Items whose chosen class is below the threshold are left
                # unclassified.
                if probas[argmax] < confidence_threshold:
                    if not is_multilabel:
                        y_pred_filter.append("__NOT_CLASSIFIED__")
                    continue

                classified_indices.append(i)
                if is_multilabel:
                    y_pred_filter.append(y_pred[i])
                else:
                    y_pred_filter.append(argmax)

            if not is_multilabel:
                # Use an object array. A numpy string array is sized to the
                # longest string it initially holds and silently truncates
                # longer class names assigned below. An int array (every item
                # classified) cannot hold string class names at all.
                y_pred_filter = np.array(y_pred_filter, dtype=object)
                y_pred_filter[classified_indices] = self.le.inverse_transform(
                    np.array(y_pred_filter[classified_indices], dtype=int)
                )

            classified_num = len(classified_indices)

            logger.info(
                f"\nConfidence threshold > {confidence_threshold} - {classified_num} classified"
            )
            if is_multilabel:
                confusion_matrix = metrics.multilabel_confusion_matrix(
                    y_test[classified_indices], np.asarray(y_pred_filter)
                )
            else:
                confusion_matrix = metrics.confusion_matrix(
                    y_test.astype(str),
                    y_pred_filter.astype(str),
                    labels=confidence_class_names,
                )
                print(
                    classification_report_imbalanced(
                        y_test.astype(str),
                        y_pred_filter.astype(str),
                        labels=confidence_class_names,
                    )
                )
            print_labeled_confusion_matrix(
                confusion_matrix, confidence_class_names, is_multilabel=is_multilabel
            )

        self.evaluation()

        if self.entire_dataset_training:
            logger.info("Retraining on the entire dataset...")

            X_train = X
            y_train = y

            logger.info(f"X_train: {X_train.shape}, y_train: {y_train.shape}")

            self.clf.fit(X_train, self.le.transform(y_train))

        model_name = self.__class__.__name__.lower()
        model_directory = model_name
        makedirs(model_directory, exist_ok=True)

        step_name, estimator = self.clf.steps.pop()
        if issubclass(type(estimator), XGBModel):
            xgboost_model_path = path.join(model_directory, "xgboost.ubj")
            estimator.save_model(xgboost_model_path)

            # Since we save the estimator separately, we need to reset it to
            # prevent its data from being pickled with the pipeline.
            hyperparameters = estimator.get_params()
            estimator = estimator.__class__(**hyperparameters)
        self.clf.steps.append((step_name, estimator))

        model_path = path.join(model_directory, "model.pkl")
        with open(model_path, "wb") as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

        if self.store_dataset:
            with open(f"{model_name}_data_X", "wb") as f:
                pickle.dump(X, f, protocol=pickle.HIGHEST_PROTOCOL)

            with open(f"{model_name}_data_y", "wb") as f:
                pickle.dump(y, f, protocol=pickle.HIGHEST_PROTOCOL)

        return tracking_metrics