in causalml/inference/tree/causal/causaltree.py [0:0]
def _count_groups_distribution(self, X: np.ndarray, treatment: np.ndarray) -> dict:
    """
    Count treatment, control distribution for tree nodes/leaves

    The counting granularity is controlled by ``self.groups_cnt_mode``:
    ``"leaves"`` credits each sample only to the last node of its decision
    path (the leaf it lands in), while ``"nodes"`` credits it to every node
    on that path.

    As a side effect, ``self.is_leaves`` is (re)computed via
    ``get_tree_leaves_mask``.

    Args:
        X: (np.ndarray), feature matrix
        treatment: (np.ndarray), treatment vector, aligned row-by-row with X
    Returns:
        dict: ``{node_id: {group: count}}`` for every node of the fitted tree;
            nodes that receive no samples keep a zero count for every group
    Raises:
        ValueError: if ``self.groups_cnt_mode`` is neither ``"leaves"`` nor
            ``"nodes"``, or if ``treatment`` and ``X`` have a different
            number of rows
    """
    check_is_fitted(self)

    # Validate up front: an unknown mode previously fell through both
    # branches and silently produced all-zero counts.
    mode = self.groups_cnt_mode
    if mode not in ("leaves", "nodes"):
        raise ValueError(
            f"groups_cnt_mode must be 'leaves' or 'nodes', got {mode!r}"
        )

    # Normalize so the positional walk below cannot be affected by, e.g.,
    # a pandas index on the treatment vector.
    treatment = np.asarray(treatment)
    n_samples = X.shape[0]
    if treatment.shape[0] != n_samples:
        raise ValueError(
            f"treatment has {treatment.shape[0]} rows but X has {n_samples}"
        )

    self.is_leaves = get_tree_leaves_mask(self)
    groups_cnt = {
        node_id: {group: 0 for group in self.treatment_groups}
        for node_id in range(self.tree_.node_count)
    }

    # Sparse (CSR) node-indicator matrix: row i holds the node ids on
    # sample i's decision path; the last entry is treated as its leaf.
    node_indicators = self.tree_.decision_path(X.astype(np.float32))
    indptr, indices = node_indicators.indptr, node_indicators.indices
    for sample_id, group in enumerate(treatment):
        nodes_path = indices[indptr[sample_id] : indptr[sample_id + 1]]
        if mode == "leaves":
            # Keep only the terminal node of the path.
            nodes_path = nodes_path[-1:]
        for node_id in nodes_path:
            groups_cnt[node_id][group] += 1
    return groups_cnt