# causalml/inference/tree/causal/_builder.pyx
# distutils: language = c++
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
# cython: language_level=3
# cython: linetrace=True
from libc.stdint cimport INTPTR_MAX
from libcpp cimport bool
from libcpp.stack cimport stack
from libcpp.vector cimport vector
from libcpp.algorithm cimport pop_heap
from libcpp.algorithm cimport push_heap
# NOTE (assumed cimports): the builders below rely on scikit-learn's Cython
# tree internals; the exact module paths here are an assumption and may
# differ in a vendored copy of those internals.
from sklearn.tree._splitter cimport Splitter, SplitRecord, ParentInfo, _init_parent_record
from sklearn.tree._tree cimport Tree, TreeBuilder, Node
from sklearn.utils._typedefs cimport float64_t, int32_t, intp_t
import numpy as np
cimport numpy as np
np.import_array()
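# Assumed declarations: scikit-learn defines StackRecord and FrontierRecord
# privately in _tree.pyx and does not export them, so matching definitions
# are sketched here; the fields mirror sklearn's structs and are exactly the
# ones accessed by the builders below.
cdef struct StackRecord:
    intp_t start
    intp_t end
    intp_t depth
    intp_t parent
    bint is_left
    float64_t impurity
    intp_t n_constant_features
    float64_t lower_bound
    float64_t upper_bound
cdef struct FrontierRecord:
    # Record of information about a node on the frontier. Records are kept
    # in a heap so the node with the best impurity improvement can be
    # accessed first, letting the tree grow greedily on that improvement.
    intp_t node_id
    intp_t start
    intp_t end
    intp_t pos
    intp_t depth
    bint is_leaf
    float64_t impurity
    float64_t impurity_left
    float64_t impurity_right
    float64_t improvement
    float64_t lower_bound
    float64_t upper_bound
    float64_t middle_value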
cdef float64_t INFINITY = np.inf
cdef float64_t EPSILON = np.finfo('double').eps
cdef int IS_FIRST = 1
cdef int IS_NOT_FIRST = 0
cdef int IS_LEFT = 1
cdef int IS_NOT_LEFT = 0
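# Sentinel values, mirroring sklearn's tree representation: TREE_LEAF marks
# the children of a leaf node; TREE_UNDEFINED marks the feature/threshold of
# leaves and the parent of the root.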
TREE_LEAF = -1
TREE_UNDEFINED = -2
cdef intp_t _TREE_LEAF = TREE_LEAF
cdef intp_t _TREE_UNDEFINED = TREE_UNDEFINED
cdef class DepthFirstCausalTreeBuilder(TreeBuilder):
"""Build a decision tree in depth-first fashion.
DepthFirstTreeBuilder modified for causal trees
Source: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/tree/_tree.pyx
"""
def __cinit__(self, Splitter splitter, intp_t min_samples_split,
intp_t min_samples_leaf, float64_t min_weight_leaf,
intp_t max_depth, float64_t min_impurity_decrease):
self.splitter = splitter
self.min_samples_split = min_samples_split
self.min_samples_leaf = min_samples_leaf
self.min_weight_leaf = min_weight_leaf
self.max_depth = max_depth
self.min_impurity_decrease = min_impurity_decrease
cpdef build(self, Tree tree, object X,
const float64_t[:, ::1] y,
const int32_t[:] treatment,
const float64_t[:] sample_weight=None,
const unsigned char[::1] missing_values_in_feature_mask=None,
):
"""Build a decision tree from the training set (X, y)."""
# check input
X, y, treatment, sample_weight = self._check_input(X, y, treatment, sample_weight)
# Initial capacity
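        # A full binary tree of depth d holds 2**(d + 1) - 1 nodes; for
        # deeper trees, start from 2047 nodes (depth 10) and grow on demand.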
cdef intp_t init_capacity
if tree.max_depth <= 10:
init_capacity = <intp_t> (2 ** (tree.max_depth + 1)) - 1
else:
init_capacity = 2047
tree._resize(init_capacity)
# Parameters
cdef Splitter splitter = self.splitter
cdef intp_t max_depth = self.max_depth
cdef intp_t min_samples_leaf = self.min_samples_leaf
cdef float64_t min_weight_leaf = self.min_weight_leaf
cdef intp_t min_samples_split = self.min_samples_split
cdef float64_t min_impurity_decrease = self.min_impurity_decrease
# Recursive partition (without actual recursion)
splitter.init(X, y, treatment, sample_weight, missing_values_in_feature_mask)
cdef intp_t start
cdef intp_t end
cdef intp_t depth
cdef intp_t parent
cdef bint is_left
cdef intp_t n_node_samples = splitter.n_samples
cdef long tr_count
cdef long ct_count
cdef float64_t weighted_n_node_samples
cdef SplitRecord split
cdef intp_t node_id
cdef float64_t middle_value
cdef float64_t left_child_min
cdef float64_t left_child_max
cdef float64_t right_child_min
cdef float64_t right_child_max
cdef intp_t n_constant_features
cdef bint is_leaf
cdef bint first = 1
cdef intp_t max_depth_seen = -1
cdef int rc = 0
cdef stack[StackRecord] builder_stack
cdef StackRecord stack_record
cdef ParentInfo parent_record
_init_parent_record(&parent_record)
with nogil:
# push root node onto stack
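            # The root carries unbounded monotonicity limits; children only
            # tighten lower_bound/upper_bound when a monotonic constraint is
            # active on the split feature.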
builder_stack.push({
"start": 0,
"end": n_node_samples,
"depth": 0,
"parent": _TREE_UNDEFINED,
"is_left": 0,
"impurity": INFINITY,
"n_constant_features": 0,
"lower_bound": -INFINITY,
"upper_bound": INFINITY,
})
while not builder_stack.empty():
stack_record = builder_stack.top()
builder_stack.pop()
start = stack_record.start
end = stack_record.end
depth = stack_record.depth
parent = stack_record.parent
is_left = stack_record.is_left
parent_record.impurity = stack_record.impurity
parent_record.n_constant_features = stack_record.n_constant_features
parent_record.lower_bound = stack_record.lower_bound
parent_record.upper_bound = stack_record.upper_bound
n_node_samples = end - start
splitter.node_reset(start, end, &weighted_n_node_samples)
with gil:
# TODO: Get tr_count and ct_count without gil
tr_count = <long> splitter.criterion.state["node"]["tr_count"]
ct_count = <long> splitter.criterion.state["node"]["ct_count"]
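                # Causal-tree modification: besides the usual sklearn checks,
                # a node must keep enough treatment (tr_count) and control
                # (ct_count) samples to estimate a treatment effect after a
                # further split.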
is_leaf = (depth >= max_depth or
n_node_samples < min_samples_split or
n_node_samples < 2 * min_samples_leaf or
tr_count < min_samples_split // 2 or
ct_count < min_samples_split // 2 or
tr_count < min_samples_leaf or
ct_count < min_samples_leaf or
weighted_n_node_samples < 2 * min_weight_leaf)
if first:
parent_record.impurity = splitter.node_impurity()
first = 0
if not is_leaf:
                    splitter.node_split(&parent_record, &split)
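                    # split.pos >= end signals that no valid split was found;
                    # the improvement test enforces min_impurity_decrease.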
is_leaf = (is_leaf or split.pos >= end or
(split.improvement + EPSILON < min_impurity_decrease))
node_id = tree._add_node(parent, is_left, is_leaf, split.feature,
split.threshold, parent_record.impurity,
n_node_samples, weighted_n_node_samples,
split.missing_go_to_left)
if node_id == INTPTR_MAX:
rc = -1
break
# Store value for all nodes, to facilitate tree/model
# inspection and interpretation
splitter.node_value(tree.value + node_id * tree.value_stride)
if splitter.with_monotonic_cst:
splitter.clip_node_value(tree.value + node_id * tree.value_stride, parent_record.lower_bound, parent_record.upper_bound)
if not is_leaf:
if (
not splitter.with_monotonic_cst or
splitter.monotonic_cst[split.feature] == 0
):
# Split on a feature with no monotonicity constraint
# Current bounds must always be propagated to both children.
# If a monotonic constraint is active, bounds are used in
# node value clipping.
left_child_min = right_child_min = parent_record.lower_bound
left_child_max = right_child_max = parent_record.upper_bound
elif splitter.monotonic_cst[split.feature] == 1:
# Split on a feature with monotonic increase constraint
left_child_min = parent_record.lower_bound
right_child_max = parent_record.upper_bound
# Lower bound for right child and upper bound for left child
# are set to the same value.
middle_value = splitter.criterion.middle_value()
right_child_min = middle_value
left_child_max = middle_value
else: # i.e. splitter.monotonic_cst[split.feature] == -1
# Split on a feature with monotonic decrease constraint
right_child_min = parent_record.lower_bound
left_child_max = parent_record.upper_bound
# Lower bound for left child and upper bound for right child
# are set to the same value.
middle_value = splitter.criterion.middle_value()
left_child_min = middle_value
right_child_max = middle_value
# Push right child on stack
builder_stack.push({
"start": split.pos,
"end": end,
"depth": depth + 1,
"parent": node_id,
"is_left": 0,
"impurity": split.impurity_right,
"n_constant_features": parent_record.n_constant_features,
"lower_bound": right_child_min,
"upper_bound": right_child_max,
})
# Push left child on stack
builder_stack.push({
"start": start,
"end": split.pos,
"depth": depth + 1,
"parent": node_id,
"is_left": 1,
"impurity": split.impurity_left,
"n_constant_features": parent_record.n_constant_features,
"lower_bound": left_child_min,
"upper_bound": left_child_max,
})
if depth > max_depth_seen:
max_depth_seen = depth
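            # Shrink the preallocated node array to the nodes actually used.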
if rc >= 0:
rc = tree._resize_c(tree.node_count)
if rc >= 0:
tree.max_depth = max_depth_seen
if rc == -1:
raise MemoryError()
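# Heap comparator for std::push_heap/std::pop_heap: ordering FrontierRecords
# by `improvement` makes the frontier vector a max-heap, so the best-first
# builder always expands the split with the largest impurity improvement.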
cdef inline bool _compare_records(
const FrontierRecord& left,
const FrontierRecord& right,
):
return left.improvement < right.improvement
cdef inline void _add_to_frontier(
FrontierRecord rec,
vector[FrontierRecord]& frontier,
) noexcept nogil:
"""Adds record `rec` to the priority queue `frontier`."""
frontier.push_back(rec)
push_heap(frontier.begin(), frontier.end(), &_compare_records)
cdef class BestFirstCausalTreeBuilder(TreeBuilder):
"""Build a decision tree in best-first fashion.
The best node to expand is given by the node at the frontier that has the highest impurity improvement.
BestFirstCausalTreeBuilder modified for causal trees
Source: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/tree/_tree.pyx
"""
cdef intp_t max_leaf_nodes
def __cinit__(self, Splitter splitter, intp_t min_samples_split,
                  intp_t min_samples_leaf, float64_t min_weight_leaf,
intp_t max_depth, intp_t max_leaf_nodes,
float64_t min_impurity_decrease):
self.splitter = splitter
self.min_samples_split = min_samples_split
self.min_samples_leaf = min_samples_leaf
self.min_weight_leaf = min_weight_leaf
self.max_depth = max_depth
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_decrease = min_impurity_decrease
cpdef build(
self,
Tree tree,
object X,
const float64_t[:, ::1] y,
const int32_t[:] treatment,
const float64_t[:] sample_weight=None,
const unsigned char[::1] missing_values_in_feature_mask=None,
):
"""Build a decision tree from the training set (X, y)."""
# check input
X, y, treatment, sample_weight = self._check_input(X, y, treatment, sample_weight)
# Parameters
cdef Splitter splitter = self.splitter
cdef intp_t max_leaf_nodes = self.max_leaf_nodes
cdef intp_t min_samples_leaf = self.min_samples_leaf
cdef float64_t min_weight_leaf = self.min_weight_leaf
cdef intp_t min_samples_split = self.min_samples_split
# Recursive partition (without actual recursion)
splitter.init(X, y, treatment, sample_weight, missing_values_in_feature_mask)
cdef vector[FrontierRecord] frontier
cdef FrontierRecord record
cdef FrontierRecord split_node_left
cdef FrontierRecord split_node_right
cdef float64_t left_child_min
cdef float64_t left_child_max
cdef float64_t right_child_min
cdef float64_t right_child_max
cdef intp_t n_node_samples = splitter.n_samples
cdef intp_t max_split_nodes = max_leaf_nodes - 1
cdef bint is_leaf
cdef intp_t max_depth_seen = -1
cdef int rc = 0
cdef Node* node
cdef ParentInfo parent_record
_init_parent_record(&parent_record)
# Initial capacity
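        # A binary tree with max_leaf_nodes leaves has at most
        # max_leaf_nodes - 1 internal split nodes, hence this capacity.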
cdef intp_t init_capacity = max_split_nodes + max_leaf_nodes
tree._resize(init_capacity)
with nogil:
# add root to frontier
rc = self._add_split_node(
splitter=splitter,
tree=tree,
start=0,
end=n_node_samples,
is_first=IS_FIRST,
is_left=IS_LEFT,
parent=NULL,
depth=0,
parent_record=&parent_record,
res=&split_node_left,
)
if rc >= 0:
_add_to_frontier(split_node_left, frontier)
while not frontier.empty():
pop_heap(frontier.begin(), frontier.end(), &_compare_records)
record = frontier.back()
frontier.pop_back()
node = &tree.nodes[record.node_id]
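                # Once the split budget (max_leaf_nodes - 1 internal nodes)
                # is exhausted, every remaining frontier node becomes a leaf.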
is_leaf = (record.is_leaf or max_split_nodes <= 0)
if is_leaf:
# Node is not expandable; set node as leaf
node.left_child = _TREE_LEAF
node.right_child = _TREE_LEAF
node.feature = _TREE_UNDEFINED
node.threshold = _TREE_UNDEFINED
else:
# Node is expandable
if (
not splitter.with_monotonic_cst or
splitter.monotonic_cst[node.feature] == 0
):
# Split on a feature with no monotonicity constraint
# Current bounds must always be propagated to both children.
# If a monotonic constraint is active, bounds are used in
# node value clipping.
left_child_min = right_child_min = record.lower_bound
left_child_max = right_child_max = record.upper_bound
elif splitter.monotonic_cst[node.feature] == 1:
# Split on a feature with monotonic increase constraint
left_child_min = record.lower_bound
right_child_max = record.upper_bound
# Lower bound for right child and upper bound for left child
# are set to the same value.
right_child_min = record.middle_value
left_child_max = record.middle_value
                else: # i.e. splitter.monotonic_cst[node.feature] == -1
# Split on a feature with monotonic decrease constraint
right_child_min = record.lower_bound
left_child_max = record.upper_bound
# Lower bound for left child and upper bound for right child
# are set to the same value.
left_child_min = record.middle_value
right_child_max = record.middle_value
# Decrement number of split nodes available
max_split_nodes -= 1
# Compute left split node
parent_record.lower_bound = left_child_min
parent_record.upper_bound = left_child_max
parent_record.impurity = record.impurity_left
rc = self._add_split_node(
splitter=splitter,
tree=tree,
start=record.start,
end=record.pos,
is_first=IS_NOT_FIRST,
is_left=IS_LEFT,
parent=node,
depth=record.depth + 1,
parent_record=&parent_record,
res=&split_node_left,
)
if rc == -1:
break
# tree.nodes may have changed
node = &tree.nodes[record.node_id]
# Compute right split node
parent_record.lower_bound = right_child_min
parent_record.upper_bound = right_child_max
parent_record.impurity = record.impurity_right
rc = self._add_split_node(
splitter=splitter,
tree=tree,
start=record.pos,
end=record.end,
is_first=IS_NOT_FIRST,
is_left=IS_NOT_LEFT,
parent=node,
depth=record.depth + 1,
parent_record=&parent_record,
res=&split_node_right,
)
if rc == -1:
break
# Add nodes to queue
_add_to_frontier(split_node_left, frontier)
_add_to_frontier(split_node_right, frontier)
if record.depth > max_depth_seen:
max_depth_seen = record.depth
if rc >= 0:
rc = tree._resize_c(tree.node_count)
if rc >= 0:
tree.max_depth = max_depth_seen
if rc == -1:
raise MemoryError()
cdef inline int _add_split_node(
self,
Splitter splitter,
Tree tree,
intp_t start,
intp_t end,
bint is_first,
bint is_left,
Node* parent,
intp_t depth,
ParentInfo* parent_record,
FrontierRecord* res
) except -1 nogil:
"""Adds node w/ partition ``[start, end)`` to the frontier. """
cdef SplitRecord split
cdef intp_t node_id
cdef intp_t n_node_samples
cdef long tr_count
cdef long ct_count
cdef float64_t min_impurity_decrease = self.min_impurity_decrease
cdef float64_t weighted_n_node_samples
cdef bint is_leaf
splitter.node_reset(start, end, &weighted_n_node_samples)
# reset n_constant_features for this specific split before beginning split search
parent_record.n_constant_features = 0
with gil:
# TODO: Get tr_count and ct_count without gil
tr_count = <long> splitter.criterion.state["node"]["tr_count"]
ct_count = <long> splitter.criterion.state["node"]["ct_count"]
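        # Causal-tree modification: the leaf checks below also require enough
        # treatment (tr_count) and control (ct_count) samples in the node.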
if is_first:
parent_record.impurity = splitter.node_impurity()
n_node_samples = end - start
        is_leaf = (depth >= self.max_depth or
                   n_node_samples < self.min_samples_split or
                   n_node_samples < 2 * self.min_samples_leaf or
                   tr_count < self.min_samples_split // 2 or
                   ct_count < self.min_samples_split // 2 or
                   tr_count < self.min_samples_leaf or
                   ct_count < self.min_samples_leaf or
                   weighted_n_node_samples < 2 * self.min_weight_leaf or
                   parent_record.impurity <= EPSILON)
if not is_leaf:
splitter.node_split(
parent_record,
&split
)
is_leaf = (is_leaf or split.pos >= end or
split.improvement + EPSILON < min_impurity_decrease)
node_id = tree._add_node(parent - tree.nodes
if parent != NULL
else _TREE_UNDEFINED,
is_left, is_leaf,
split.feature, split.threshold, parent_record.impurity,
n_node_samples, weighted_n_node_samples,
split.missing_go_to_left)
if node_id == INTPTR_MAX:
return -1
        # compute values also for split nodes (they might become leaves later).
splitter.node_value(tree.value + node_id * tree.value_stride)
if splitter.with_monotonic_cst:
splitter.clip_node_value(tree.value + node_id * tree.value_stride, parent_record.lower_bound, parent_record.upper_bound)
res.node_id = node_id
res.start = start
res.end = end
res.depth = depth
res.impurity = parent_record.impurity
res.lower_bound = parent_record.lower_bound
res.upper_bound = parent_record.upper_bound
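        # Midpoint between the children's node values; used as the shared
        # monotonicity bound if this node is expanded under a constraint.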
res.middle_value = splitter.criterion.middle_value()
if not is_leaf:
# is split node
res.pos = split.pos
res.is_leaf = 0
res.improvement = split.improvement
res.impurity_left = split.impurity_left
res.impurity_right = split.impurity_right
else:
# is leaf => 0 improvement
res.pos = end
res.is_leaf = 1
res.improvement = 0.0
res.impurity_left = parent_record.impurity
res.impurity_right = parent_record.impurity
return 0
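# Usage sketch (illustrative assumption, not part of this module): these
# builders are constructed and driven from Python by causalml's causal tree
# estimators, e.g. CausalTreeRegressor; the parameters shown here are
# assumptions for illustration.
#
#     from causalml.inference.tree import CausalTreeRegressor
#
#     ctree = CausalTreeRegressor(control_name=0)
#     ctree.fit(X=X_train, treatment=w_train, y=y_train)
#     cate = ctree.predict(X_test)  # per-sample treatment effect estimates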