in causalml/inference/tree/uplift.pyx [0:0]
def pruneTree(self, X, treatment_idx, y, tree, rule='maxAbsDiff', minGain=0.,
              n_reg=0,
              parentNodeSummary=None):
    """Prune one single tree node in the uplift model.

    Recursively descends into non-leaf children first, then, if both
    children are leaves, merges them back into this node whenever the
    validation-set gain of keeping the split does not exceed ``minGain``.

    Args
    ----
    X : ndarray, shape = [num_samples, num_features]
        An ndarray of the covariates used to train the uplift model.
    treatment_idx : array-like, shape = [num_samples]
        An array containing the treatment group index for each unit.
    y : array-like, shape = [num_samples]
        An array containing the outcome of interest for each unit.
    tree : object
        The tree node at the root of the subtree to prune.
    rule : string, optional (default = 'maxAbsDiff')
        The prune rules. Supported values are 'maxAbsDiff' for optimizing the maximum absolute difference, and
        'bestUplift' for optimizing the node-size weighted treatment effect.
    minGain : float, optional (default = 0.)
        The minimum gain required to make a tree node split. The children tree branches are trimmed if the actual
        split gain is less than the minimum gain.
    n_reg: int, optional (default=0)
        The regularization parameter defined in Rzepakowski et al. 2012, the weight (in terms of sample size) of the
        parent node influence on the child node, only effective for 'KL', 'ED', 'Chi', 'CTS' methods.
    parentNodeSummary : list of list, optional (default = None)
        Node summary statistics, [P(Y=1|T), N(T)] of the parent tree node.
    Returns
    -------
    self : object
    """
    # Current Node Summary for Validation Data Set
    currentNodeSummary = self.tree_node_summary(
        treatment_idx, y, min_samples_treatment=self.min_samples_treatment,
        n_reg=n_reg, parentNodeSummary=parentNodeSummary
    )
    tree.nodeSummary = currentNodeSummary
    # Leaf node: nothing below it to prune or merge.
    # BUG FIX: the original code gated divideSet on this condition instead of
    # returning, which (a) left X_l/w_l/y_l undefined for internal nodes and
    # (b) called divideSet with tree.col/tree.value on leaves that lack them.
    if (tree.trueBranch is None) or (tree.falseBranch is None):
        return self
    # Divide sets for child nodes
    X_l, X_r, w_l, w_r, y_l, y_r = self.divideSet(X, treatment_idx, y, tree.col, tree.value)
    # recursive call for each non-leaf branch (results is None <=> split node)
    if tree.trueBranch.results is None:
        self.pruneTree(X_l, w_l, y_l, tree.trueBranch, rule, minGain,
                       n_reg,
                       parentNodeSummary=currentNodeSummary)
    if tree.falseBranch.results is None:
        self.pruneTree(X_r, w_r, y_r, tree.falseBranch, rule, minGain,
                       n_reg,
                       parentNodeSummary=currentNodeSummary)
    # merge leaves (potentially): only when both children are leaves after
    # the recursive pruning above
    if (tree.trueBranch.results is not None and
        tree.falseBranch.results is not None):
        if rule == 'maxAbsDiff':
            # Current D: signed treatment effect of this node; 0 if either
            # group is missing from the node summary
            if (tree.maxDiffTreatment in currentNodeSummary and
                self.control_name in currentNodeSummary):
                currentScoreD = tree.maxDiffSign * (currentNodeSummary[tree.maxDiffTreatment][0]
                                                   - currentNodeSummary[self.control_name][0])
            else:
                currentScoreD = 0
            # trueBranch D, weighted by the child's share of the relevant
            # sample sizes relative to this node
            trueNodeSummary = self.tree_node_summary(
                w_l, y_l, min_samples_treatment=self.min_samples_treatment,
                n_reg=n_reg, parentNodeSummary=currentNodeSummary
            )
            if (tree.trueBranch.maxDiffTreatment in trueNodeSummary and
                self.control_name in trueNodeSummary):
                trueScoreD = tree.trueBranch.maxDiffSign * (trueNodeSummary[tree.trueBranch.maxDiffTreatment][0]
                                                            - trueNodeSummary[self.control_name][0])
                trueScoreD = (
                    trueScoreD
                    * (trueNodeSummary[tree.trueBranch.maxDiffTreatment][1]
                       + trueNodeSummary[self.control_name][1])
                    / (currentNodeSummary[tree.trueBranch.maxDiffTreatment][1]
                       + currentNodeSummary[self.control_name][1])
                )
            else:
                trueScoreD = 0
            # falseBranch D, weighted the same way
            falseNodeSummary = self.tree_node_summary(
                w_r, y_r, min_samples_treatment=self.min_samples_treatment,
                n_reg=n_reg, parentNodeSummary=currentNodeSummary
            )
            if (tree.falseBranch.maxDiffTreatment in falseNodeSummary and
                self.control_name in falseNodeSummary):
                falseScoreD = (
                    tree.falseBranch.maxDiffSign *
                    (falseNodeSummary[tree.falseBranch.maxDiffTreatment][0]
                     - falseNodeSummary[self.control_name][0])
                )
                falseScoreD = (
                    falseScoreD *
                    (falseNodeSummary[tree.falseBranch.maxDiffTreatment][1]
                     + falseNodeSummary[self.control_name][1])
                    / (currentNodeSummary[tree.falseBranch.maxDiffTreatment][1]
                       + currentNodeSummary[self.control_name][1])
                )
            else:
                falseScoreD = 0
            # Prune when the split's gain does not clear minGain, or the
            # combined child score turned negative
            if ((trueScoreD + falseScoreD) - currentScoreD <= minGain or
                (trueScoreD + falseScoreD < 0.)):
                tree.trueBranch, tree.falseBranch = None, None
                tree.results = tree.backupResults
        elif rule == 'bestUplift':
            # Current D: uplift of the best treatment vs. control at this node
            if (tree.bestTreatment in currentNodeSummary and
                self.control_name in currentNodeSummary):
                currentScoreD = (
                    currentNodeSummary[tree.bestTreatment][0]
                    - currentNodeSummary[self.control_name][0]
                )
            else:
                currentScoreD = 0
            # trueBranch D
            trueNodeSummary = self.tree_node_summary(
                w_l, y_l, min_samples_treatment=self.min_samples_treatment,
                n_reg=n_reg, parentNodeSummary=currentNodeSummary
            )
            if (tree.trueBranch.bestTreatment in trueNodeSummary and
                self.control_name in trueNodeSummary):
                trueScoreD = (
                    trueNodeSummary[tree.trueBranch.bestTreatment][0]
                    - trueNodeSummary[self.control_name][0]
                )
            else:
                trueScoreD = 0
            # falseBranch D
            falseNodeSummary = self.tree_node_summary(
                w_r, y_r, min_samples_treatment=self.min_samples_treatment,
                n_reg=n_reg, parentNodeSummary=currentNodeSummary
            )
            if (tree.falseBranch.bestTreatment in falseNodeSummary and
                self.control_name in falseNodeSummary):
                falseScoreD = (
                    falseNodeSummary[tree.falseBranch.bestTreatment][0]
                    - falseNodeSummary[self.control_name][0]
                )
            else:
                falseScoreD = 0
            # Node-size weighted gain of keeping the split
            gain = ((1. * len(y_l) / len(y) * trueScoreD
                     + 1. * len(y_r) / len(y) * falseScoreD)
                    - currentScoreD)
            if gain <= minGain or (trueScoreD + falseScoreD < 0.):
                tree.trueBranch, tree.falseBranch = None, None
                tree.results = tree.backupResults
    return self