in causalml/inference/tree/uplift.pyx [0:0]
def fit(self, X, treatment, y, X_val=None, treatment_val=None, y_val=None):
    """
    Fit the UpliftRandomForestClassifier.

    Args
    ----
    X : ndarray, shape = [num_samples, num_features]
        An ndarray of the covariates used to train the uplift model.
    treatment : array-like, shape = [num_samples]
        An array containing the treatment group for each unit.
    y : array-like, shape = [num_samples]
        An array containing the outcome of interest for each unit.
    X_val : ndarray, shape = [num_samples, num_features]
        An ndarray of the covariates used to validate the uplift model.
    treatment_val : array-like, shape = [num_samples]
        An array containing the validation treatment group for each unit.
    y_val : array-like, shape = [num_samples]
        An array containing the validation outcome of interest for each unit.

    Notes
    -----
    Sets the following attributes:

    - ``uplift_forest``: list of objects returned by ``self.bootstrap``, one
      per tree.
    - ``classes_``: ``[control_name]`` followed by the sorted remaining
      treatment labels found in ``treatment``.
    - ``n_class``: ``len(classes_)``.
    - ``feature_importances_``: mean of the per-tree importances, rescaled
      to sum to 1. If the total is exactly zero (e.g. no tree produced a
      split), the zeros are kept as-is instead of being divided by zero,
      which would yield NaNs.
    """
    random_state = check_random_state(self.random_state)

    # Create forest. Each tree receives its own integer seed drawn from the
    # forest-level random state.
    self.uplift_forest = [
        UpliftTreeClassifier(
            max_features=self.max_features,
            max_depth=self.max_depth,
            min_samples_leaf=self.min_samples_leaf,
            min_samples_treatment=self.min_samples_treatment,
            n_reg=self.n_reg,
            early_stopping_eval_diff_scale=self.early_stopping_eval_diff_scale,
            evaluationFunction=self.evaluationFunction,
            control_name=self.control_name,
            normalization=self.normalization,
            honesty=self.honesty,
            random_state=random_state.randint(MAX_INT),
        )
        for _ in range(self.n_estimators)
    ]

    # Get treatment group keys. self.classes_[0] is reserved for the control group.
    treatment_groups = sorted(x for x in set(treatment) if x != self.control_name)
    self.classes_ = [self.control_name] + treatment_groups
    self.n_class = len(self.classes_)

    # Run self.bootstrap (defined elsewhere) once per tree via joblib; its
    # return values replace the unfitted trees in self.uplift_forest.
    self.uplift_forest = Parallel(n_jobs=self.n_jobs, prefer=self.joblib_prefer)(
        delayed(self.bootstrap)(X, treatment, y, X_val, treatment_val, y_val, tree)
        for tree in self.uplift_forest
    )

    all_importances = [tree.feature_importances_ for tree in self.uplift_forest]
    self.feature_importances_ = np.mean(all_importances, axis=0)
    total_importance = self.feature_importances_.sum()
    # Guard against 0/0: when no tree split on any feature, every importance
    # is zero and normalizing would turn the whole vector into NaN.
    if total_importance != 0:
        self.feature_importances_ /= total_importance  # normalize to add to 1