in bugbug/model.py [0:0]
def train(self, importance_cutoff=0.15, limit=None):
    classes, self.class_names = self.get_labels()
    self.class_names = sort_class_names(self.class_names)

    # Get items and labels, filtering out those for which we have no labels.
    X_gen, y = split_tuple_generator(lambda: self.items_gen(classes))

    # Extract features from the items.
    X = self.extraction_pipeline.transform(X_gen)

    # Calculate labels.
    y = np.array(y)
    self.le.fit(y)

    if limit:
        X = X[:limit]
        y = y[:limit]

    logger.info(f"X: {X.shape}, y: {y.shape}")
    is_multilabel = isinstance(y[0], np.ndarray)
    is_binary = len(self.class_names) == 2

    # Split dataset in training and test.
    X_train, X_test, y_train, y_test = self.train_test_split(X, y)

    tracking_metrics = {}

    # Use k-fold cross validation to evaluate results.
    if self.cross_validation_enabled:
        scorings = ["accuracy"]
        if len(self.class_names) == 2:
            scorings += ["precision", "recall"]

        scores = cross_validate(
            self.clf, X_train, self.le.transform(y_train), scoring=scorings, cv=5
        )

        logger.info("Cross Validation scores:")
        for scoring in scorings:
            score = scores[f"test_{scoring}"]
            tracking_metrics[f"test_{scoring}"] = {
                "mean": score.mean(),
                "std": score.std() * 2,
            }
            logger.info(
                f"{scoring.capitalize()}: {score.mean()} (+/- {score.std() * 2})"
            )

logger.info(f"X_train: {X_train.shape}, y_train: {y_train.shape}")
logger.info(f"X_test: {X_test.shape}, y_test: {y_test.shape}")
    self.clf.fit(X_train, self.le.transform(y_train))

    logger.info("Number of features: %d", self.clf.steps[-1][1].n_features_in_)
    logger.info("Model trained")

    feature_names = self.get_human_readable_feature_names()
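
    # Optionally compute SHAP values for the trained estimator, plot a feature
    # importance summary, and add the important features to the metric report.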
    if self.calculate_importance and len(feature_names):
        explainer = shap.TreeExplainer(self.clf.named_steps["estimator"])
        _X_train = get_transformer_pipeline(self.clf).transform(X_train)
        shap_values = explainer.shap_values(_X_train)

        # In the binary case, sometimes shap returns a single shap values matrix.
        if is_binary and not isinstance(shap_values, list):
            shap_values = [-shap_values, shap_values]
            summary_plot_value = shap_values[1]
            summary_plot_type = "layered_violin"
        else:
            summary_plot_value = shap_values
            summary_plot_type = None

        shap.summary_plot(
            summary_plot_value,
            to_array(_X_train),
            feature_names=feature_names,
            class_names=self.class_names,
            plot_type=summary_plot_type,
            show=False,
        )

        matplotlib.pyplot.xlabel("Impact on model output")
        matplotlib.pyplot.savefig("feature_importance.png", bbox_inches="tight")
        matplotlib.pyplot.clf()

        important_features = self.get_important_features(
            importance_cutoff, shap_values
        )

        self.print_feature_importances(important_features)

        # Save the important features in the metric report too.
        feature_report = self.save_feature_importances(
            important_features, feature_names
        )

        tracking_metrics["feature_report"] = feature_report
logger.info("Training Set scores:")
y_pred = self.clf.predict(X_train)
y_pred = self.le.inverse_transform(y_pred)
if not is_multilabel:
print(
classification_report_imbalanced(
y_train, y_pred, labels=self.class_names
)
)
logger.info("Test Set scores:")
# Evaluate results on the test set.
y_pred = self.clf.predict(X_test)
y_pred = self.le.inverse_transform(y_pred)
if is_multilabel:
assert isinstance(y_pred[0], np.ndarray), (
"The predictions should be multilabel"
)
logger.info(f"No confidence threshold - {len(y_test)} classified")
if is_multilabel:
confusion_matrix = metrics.multilabel_confusion_matrix(y_test, y_pred)
else:
confusion_matrix = metrics.confusion_matrix(
y_test, y_pred, labels=self.class_names
)
print(
classification_report_imbalanced(
y_test, y_pred, labels=self.class_names
)
)
report = classification_report_imbalanced_values(
y_test, y_pred, labels=self.class_names
)
tracking_metrics["report"] = report
print_labeled_confusion_matrix(
confusion_matrix, self.class_names, is_multilabel=is_multilabel
)
tracking_metrics["confusion_matrix"] = confusion_matrix.tolist()
    confidence_thresholds = [0.6, 0.7, 0.8, 0.9]

    if is_binary:
        confidence_thresholds = [0.1, 0.2, 0.3, 0.4] + confidence_thresholds

    # Evaluate results on the test set for some confidence thresholds.
    for confidence_threshold in confidence_thresholds:
        y_pred_probas = self.clf.predict_proba(X_test)
        confidence_class_names = self.class_names + ["__NOT_CLASSIFIED__"]

        y_pred_filter = []
        classified_indices = []
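
        # Keep a prediction only when its probability clears the threshold;
        # otherwise mark the item as "__NOT_CLASSIFIED__" (single-label case)
        # or skip it entirely (multilabel case).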
        for i in range(len(y_test)):
            if not is_binary:
                argmax = np.argmax(y_pred_probas[i])
            else:
                argmax = 1 if y_pred_probas[i][1] > confidence_threshold else 0

            if y_pred_probas[i][argmax] < confidence_threshold:
                if not is_multilabel:
                    y_pred_filter.append("__NOT_CLASSIFIED__")
                continue

            classified_indices.append(i)
            if is_multilabel:
                y_pred_filter.append(y_pred[i])
            else:
                y_pred_filter.append(argmax)

        if not is_multilabel:
            y_pred_filter = np.array(y_pred_filter)
            y_pred_filter[classified_indices] = self.le.inverse_transform(
                np.array(y_pred_filter[classified_indices], dtype=int)
            )
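
        # Count how many test items received a confident prediction at this
        # threshold.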
        if is_multilabel:
            classified_num = len(classified_indices)
        else:
            classified_num = sum(
                1 for v in y_pred_filter if v != "__NOT_CLASSIFIED__"
            )

        logger.info(
            f"\nConfidence threshold > {confidence_threshold} - {classified_num} classified"
        )

        if is_multilabel:
            confusion_matrix = metrics.multilabel_confusion_matrix(
                y_test[classified_indices], np.asarray(y_pred_filter)
            )
        else:
            confusion_matrix = metrics.confusion_matrix(
                y_test.astype(str),
                y_pred_filter.astype(str),
                labels=confidence_class_names,
            )
            print(
                classification_report_imbalanced(
                    y_test.astype(str),
                    y_pred_filter.astype(str),
                    labels=confidence_class_names,
                )
            )
        print_labeled_confusion_matrix(
            confusion_matrix, confidence_class_names, is_multilabel=is_multilabel
        )
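
    # Hook for any additional, model-specific evaluation.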
    self.evaluation()

    if self.entire_dataset_training:
        logger.info("Retraining on the entire dataset...")

        X_train = X
        y_train = y

        logger.info(f"X_train: {X_train.shape}, y_train: {y_train.shape}")

        self.clf.fit(X_train, self.le.transform(y_train))
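
    # Persist the trained model. An XGBoost estimator is saved separately
    # (xgboost.ubj) and replaced by an unfitted clone with the same
    # hyperparameters before the pipeline itself is pickled.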
    model_directory = self.__class__.__name__.lower()
    makedirs(model_directory, exist_ok=True)

    step_name, estimator = self.clf.steps.pop()
    if issubclass(type(estimator), XGBModel):
        xgboost_model_path = path.join(model_directory, "xgboost.ubj")
        estimator.save_model(xgboost_model_path)

        # Since we save the estimator separately, we need to reset it to
        # prevent its data from being pickled with the pipeline.
        hyperparameters = estimator.get_params()
        estimator = estimator.__class__(**hyperparameters)
    self.clf.steps.append((step_name, estimator))

    model_path = path.join(model_directory, "model.pkl")
    with open(model_path, "wb") as f:
        pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    if self.store_dataset:
        with open(f"{self.__class__.__name__.lower()}_data_X", "wb") as f:
            pickle.dump(X, f, protocol=pickle.HIGHEST_PROTOCOL)
        with open(f"{self.__class__.__name__.lower()}_data_y", "wb") as f:
            pickle.dump(y, f, protocol=pickle.HIGHEST_PROTOCOL)

    return tracking_metrics