in causalml/inference/tree/uplift.pyx [0:0]
def classify(observations, tree, dataMissing=False):
    '''
    Classifies (prediction) the observations according to the tree.

    Args
    ----
    observations : list of list
        The internal data format for the training data (combining X, Y, treatment).
        It is indexed with ``tree.col`` at every decision node to obtain the
        value that is tested against ``tree.value``.
    tree : object
        The root node to start from. A node with ``results`` set (not None) is
        treated as a leaf; otherwise ``col``, ``value``, ``trueBranch`` and
        ``falseBranch`` are read to descend.
    dataMissing: boolean, optional (default = False)
        An indicator for if data are missing or not. When True, a feature value
        of ``None`` is treated as missing.

    Returns
    -------
    tree.results, tree.upliftScore :
        The results in the leaf node, when ``dataMissing`` is False.
    dict :
        When ``dataMissing`` is True only the results are returned (no
        upliftScore): either the ``results`` of the leaf reached, or, when a
        missing value was met on the way, a weighted blend of the results of
        both sub-trees.
    '''
    def chooseBranch(featureValue, node):
        '''
        Returns the child of ``node`` that ``featureValue`` falls into.
        '''
        # int / float values are split on a threshold, anything else on equality.
        if isinstance(featureValue, (int, float)):
            goesTrue = featureValue >= node.value
        else:
            goesTrue = featureValue == node.value
        return node.trueBranch if goesTrue else node.falseBranch

    def classifyWithoutMissingData(observations, tree):
        '''
        Classifies (prediction) the observations according to the tree, assuming without missing data.

        Args
        ----
        observations : list of list
            The internal data format for the training data (combining X, Y, treatment).

        Returns
        -------
        tree.results, tree.upliftScore :
            The results in the leaf node.
        '''
        node = tree
        # Walk down until a leaf (a node carrying results) is reached.
        while node.results is None:
            node = chooseBranch(observations[node.col], node)
        return node.results, node.upliftScore

    def classifyWithMissingData(observations, tree):
        '''
        Classifies (prediction) the observations according to the tree, assuming with missing data.

        Args
        ----
        observations : list of list
            The internal data format for the training data (combining X, Y, treatment).

        Returns
        -------
        dict :
            The results in the leaf node, or a weighted combination of the
            results of both branches when the tested value is None.
        '''
        if tree.results is not None:  # leaf
            return tree.results

        featureValue = observations[tree.col]
        if featureValue is not None:
            return classifyWithMissingData(observations, chooseBranch(featureValue, tree))

        # Missing value: follow both branches and blend their results, each
        # branch weighted by its share of the summed result values.
        tr = classifyWithMissingData(observations, tree.trueBranch)
        fr = classifyWithMissingData(observations, tree.falseBranch)
        tcount = sum(tr.values())
        fcount = sum(fr.values())
        total = tcount + fcount
        if total == 0:
            # Both sides sum to zero, so the shares are undefined; weight the
            # branches equally instead of raising ZeroDivisionError.
            tw = fw = 0.5
        else:
            tw = float(tcount) / total
            fw = float(fcount) / total

        # Problem description: http://blog.ludovf.net/python-collections-defaultdict/
        result = defaultdict(int)
        for k, branchValue in tr.items():
            result[k] += branchValue * tw
        for k, branchValue in fr.items():
            result[k] += branchValue * fw
        return dict(result)

    # function body
    if dataMissing:
        return classifyWithMissingData(observations, tree)
    return classifyWithoutMissingData(observations, tree)