def _count_groups_distribution()

in causalml/inference/tree/causal/causaltree.py [0:0]


    def _count_groups_distribution(self, X: np.ndarray, treatment: np.ndarray) -> dict:
        """
        Tally, per tree node, how many samples of each treatment group reach it
        Args:
            X: (np.ndarray), feature matrix
            treatment: (np.ndarray), treatment vector, indexed by sample position
        Returns:
            dict: node index -> {treatment group -> sample count}; every node of the
                tree gets an entry, with counters starting at zero
        """
        check_is_fitted(self)

        # Side effect kept from the original contract: the leaves mask is cached on self.
        self.is_leaves = get_tree_leaves_mask(self)

        # One zero-initialised counter per (node, treatment group) pair.
        counts = {}
        for node_idx in np.arange(self.tree_.node_count):
            counts[node_idx] = dict.fromkeys(self.treatment_groups, 0)

        # Sparse (CSR) indicator matrix: row i lists the nodes sample i passes through.
        path_matrix = self.tree_.decision_path(X.astype(np.float32))

        for sample_id in range(X.shape[0]):
            row_start = path_matrix.indptr[sample_id]
            row_stop = path_matrix.indptr[sample_id + 1]
            visited = path_matrix.indices[row_start:row_stop]

            mode = self.groups_cnt_mode
            if mode == "nodes":
                # Credit the sample to every node along its path.
                for node_id in visited:
                    counts[node_id][treatment[sample_id]] += 1
            elif mode == "leaves":
                # Only the last node on the path is credited (treated as the leaf).
                terminal_id = visited[-1]
                counts[terminal_id][treatment[sample_id]] += 1
            # Any other mode leaves the counters untouched.

        return counts