in causalml/inference/tree/models.py [0:0]
def pruneTree(self, X, treatment, y, tree, rule='maxAbsDiff', minGain=0.,
              evaluationFunction=None, notify=False, n_reg=0,
              parentNodeSummary=None):
    """Prune one single tree node in the uplift model.

    Pruning is bottom-up: the method first recurses into any non-leaf
    child branch, then considers merging the two children of ``tree``
    back into it once both are leaves. ``tree`` is modified in place.

    Args
    ----
    X : ndarray, shape = [num_samples, num_features]
        An ndarray of the covariates used to train the uplift model.
    treatment : array-like, shape = [num_samples]
        An array containing the treatment group for each unit.
    y : array-like, shape = [num_samples]
        An array containing the outcome of interest for each unit.
    tree : object
        The tree node to prune. Its ``nodeSummary``, ``results``,
        ``trueBranch`` and ``falseBranch`` attributes may be updated
        in place.
    rule : string, optional (default = 'maxAbsDiff')
        The prune rules. Supported values are 'maxAbsDiff' for optimizing the maximum absolute difference, and
        'bestUplift' for optimizing the node-size weighted treatment effect.
    minGain : float, optional (default = 0.)
        The minimum gain required to keep a tree node split. The children tree branches are trimmed if the actual
        split gain is less than the minimum gain.
    evaluationFunction : string, optional (default = None)
        Choose from one of the models: 'KL', 'ED', 'Chi', 'CTS'. Not used
        directly here; only forwarded to the recursive calls.
    notify: bool, optional (default = False)
        Kept for API compatibility; only forwarded to the recursive calls.
    n_reg: int, optional (default=0)
        The regularization parameter defined in Rzepakowski et al. 2012, the weight (in terms of sample size) of the
        parent node influence on the child node, only effective for 'KL', 'ED', 'Chi', 'CTS' methods.
    parentNodeSummary : dictionary, optional (default = None)
        Node summary statistics of the parent tree node.

    Returns
    -------
    self : object
    """
    # Current node summary on the validation data set; also serves as the
    # parent summary when recursing into / summarizing the child nodes.
    currentNodeSummary = self.tree_node_summary(
        treatment, y, min_samples_treatment=self.min_samples_treatment,
        n_reg=n_reg, parentNodeSummary=parentNodeSummary
    )
    tree.nodeSummary = currentNodeSummary
    # Divide the validation set for the child nodes using the split learned
    # at training time (tree.col, tree.value).
    X_l, X_r, w_l, w_r, y_l, y_r = self.divideSet(X, treatment, y, tree.col, tree.value)
    # Recursively prune each non-leaf branch first (bottom-up pruning).
    if tree.trueBranch.results is None:
        self.pruneTree(X_l, w_l, y_l, tree.trueBranch, rule, minGain,
                       evaluationFunction, notify, n_reg,
                       parentNodeSummary=currentNodeSummary)
    if tree.falseBranch.results is None:
        self.pruneTree(X_r, w_r, y_r, tree.falseBranch, rule, minGain,
                       evaluationFunction, notify, n_reg,
                       parentNodeSummary=currentNodeSummary)
    # Merge the two leaves back into this node (potentially) if the split
    # gain on the validation data does not justify keeping them.
    if (tree.trueBranch.results is not None and
            tree.falseBranch.results is not None):
        # Child node summaries are needed by both pruning rules, so compute
        # them once here instead of duplicating the calls per rule.
        trueNodeSummary = self.tree_node_summary(
            w_l, y_l, min_samples_treatment=self.min_samples_treatment,
            n_reg=n_reg, parentNodeSummary=currentNodeSummary
        )
        falseNodeSummary = self.tree_node_summary(
            w_r, y_r, min_samples_treatment=self.min_samples_treatment,
            n_reg=n_reg, parentNodeSummary=currentNodeSummary
        )
        if rule == 'maxAbsDiff':
            # Current node score: signed max treatment-control difference;
            # zero when either group is missing from the node summary.
            if (tree.maxDiffTreatment in currentNodeSummary and
                    self.control_name in currentNodeSummary):
                currentScoreD = tree.maxDiffSign * (currentNodeSummary[tree.maxDiffTreatment][0]
                                                    - currentNodeSummary[self.control_name][0])
            else:
                currentScoreD = 0
            # trueBranch score, weighted by its sample share (treatment +
            # control counts) relative to the current node.
            if (tree.trueBranch.maxDiffTreatment in trueNodeSummary and
                    self.control_name in trueNodeSummary):
                trueScoreD = tree.trueBranch.maxDiffSign * (trueNodeSummary[tree.trueBranch.maxDiffTreatment][0]
                                                            - trueNodeSummary[self.control_name][0])
                trueScoreD = (
                    trueScoreD
                    * (trueNodeSummary[tree.trueBranch.maxDiffTreatment][1]
                       + trueNodeSummary[self.control_name][1])
                    / (currentNodeSummary[tree.trueBranch.maxDiffTreatment][1]
                       + currentNodeSummary[self.control_name][1])
                )
            else:
                trueScoreD = 0
            # falseBranch score, weighted the same way.
            if (tree.falseBranch.maxDiffTreatment in falseNodeSummary and
                    self.control_name in falseNodeSummary):
                falseScoreD = (
                    tree.falseBranch.maxDiffSign *
                    (falseNodeSummary[tree.falseBranch.maxDiffTreatment][0]
                     - falseNodeSummary[self.control_name][0])
                )
                falseScoreD = (
                    falseScoreD *
                    (falseNodeSummary[tree.falseBranch.maxDiffTreatment][1]
                     + falseNodeSummary[self.control_name][1])
                    / (currentNodeSummary[tree.falseBranch.maxDiffTreatment][1]
                       + currentNodeSummary[self.control_name][1])
                )
            else:
                falseScoreD = 0
            # Prune when the children do not beat the parent by at least
            # minGain, or when the combined child score is negative.
            if ((trueScoreD + falseScoreD) - currentScoreD <= minGain or
                    (trueScoreD + falseScoreD < 0.)):
                tree.trueBranch, tree.falseBranch = None, None
                tree.results = tree.backupResults
        elif rule == 'bestUplift':
            # Current node score: treatment effect of the best treatment
            # versus control.
            if (tree.bestTreatment in currentNodeSummary and
                    self.control_name in currentNodeSummary):
                currentScoreD = (
                    currentNodeSummary[tree.bestTreatment][0]
                    - currentNodeSummary[self.control_name][0]
                )
            else:
                currentScoreD = 0
            # trueBranch treatment effect (unweighted here; weighting is
            # applied in the gain computation below).
            if (tree.trueBranch.bestTreatment in trueNodeSummary and
                    self.control_name in trueNodeSummary):
                trueScoreD = (
                    trueNodeSummary[tree.trueBranch.bestTreatment][0]
                    - trueNodeSummary[self.control_name][0]
                )
            else:
                trueScoreD = 0
            # falseBranch treatment effect.
            if (tree.falseBranch.bestTreatment in falseNodeSummary and
                    self.control_name in falseNodeSummary):
                falseScoreD = (
                    falseNodeSummary[tree.falseBranch.bestTreatment][0]
                    - falseNodeSummary[self.control_name][0]
                )
            else:
                falseScoreD = 0
            # Gain: node-size weighted average of the child scores minus
            # the parent score.
            gain = ((1. * len(y_l) / len(y) * trueScoreD
                     + 1. * len(y_r) / len(y) * falseScoreD)
                    - currentScoreD)
            if gain <= minGain or (trueScoreD + falseScoreD < 0.):
                tree.trueBranch, tree.falseBranch = None, None
                tree.results = tree.backupResults
    return self