def pruneTree()

in causalml/inference/tree/uplift.pyx


    def pruneTree(self, X, treatment_idx, y, tree, rule='maxAbsDiff', minGain=0.,
                  n_reg=0,
                  parentNodeSummary=None):
        """Prune one single tree node in the uplift model.
        Args
        ----
        X : ndarray, shape = [num_samples, num_features]
            An ndarray of the covariates used to train the uplift model.
        treatment_idx : array-like, shape = [num_samples]
            An array containing the treatment group index for each unit.
        y : array-like, shape = [num_samples]
            An array containing the outcome of interest for each unit.
        tree : object
            The tree node to prune, i.e. the root of the subtree considered for pruning.
        rule : string, optional (default = 'maxAbsDiff')
            The pruning rule. Supported values are 'maxAbsDiff' to optimize the maximum absolute difference, and
            'bestUplift' to optimize the node-size weighted treatment effect.
        minGain : float, optional (default = 0.)
            The minimum gain required to make a tree node split. The children tree branches are trimmed if the actual
            split gain is less than the minimum gain.
        n_reg : int, optional (default = 0)
            The regularization parameter defined in Rzepakowski et al. 2012: the weight (in terms of sample size) of
            the parent node influence on the child node. Only effective for the 'KL', 'ED', 'Chi', and 'CTS' methods.
        parentNodeSummary : dict, optional (default = None)
            Node summary statistics of the parent tree node, mapping each treatment group to [P(Y=1|T), N(T)].
        Returns
        -------
        self : object
        """
        # Current Node Summary for Validation Data Set
        currentNodeSummary = self.tree_node_summary(
            treatment_idx, y, min_samples_treatment=self.min_samples_treatment,
            n_reg=n_reg, parentNodeSummary=parentNodeSummary
        )
        tree.nodeSummary = currentNodeSummary
        # Divide sets for child nodes
        X_l, X_r, w_l, w_r, y_l, y_r = self.divideSet(X, treatment_idx, y, tree.col, tree.value)

        # recursive call for each non-leaf branch
        if tree.trueBranch.results is None:
            self.pruneTree(X_l, w_l, y_l, tree.trueBranch, rule, minGain,
                           n_reg,
                           parentNodeSummary=currentNodeSummary)
        if tree.falseBranch.results is None:
            self.pruneTree(X_r, w_r, y_r, tree.falseBranch, rule, minGain,
                           n_reg,
                           parentNodeSummary=currentNodeSummary)

        # merge leaves (potentially)
        if (tree.trueBranch.results is not None and
            tree.falseBranch.results is not None):
            if rule == 'maxAbsDiff':
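                # 'maxAbsDiff' rule: the split is kept only if the children's size-weighted, signed
                # max treatment-vs-control differences beat the parent's by more than minGain
                # (and their sum is non-negative); otherwise the node is collapsed into a leaf.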
                # Current D
                if (tree.maxDiffTreatment in currentNodeSummary and
                    self.control_name in currentNodeSummary):
                    currentScoreD = tree.maxDiffSign * (currentNodeSummary[tree.maxDiffTreatment][0]
                                                        - currentNodeSummary[self.control_name][0])
                else:
                    currentScoreD = 0

                # trueBranch D
                trueNodeSummary = self.tree_node_summary(
                    w_l, y_l, min_samples_treatment=self.min_samples_treatment,
                    n_reg=n_reg, parentNodeSummary=currentNodeSummary
                )
                if (tree.trueBranch.maxDiffTreatment in trueNodeSummary and
                    self.control_name in trueNodeSummary):
                    trueScoreD = tree.trueBranch.maxDiffSign * (trueNodeSummary[tree.trueBranch.maxDiffTreatment][0]
                                                                - trueNodeSummary[self.control_name][0])
                    trueScoreD = (
                        trueScoreD
                        * (trueNodeSummary[tree.trueBranch.maxDiffTreatment][1]
                         + trueNodeSummary[self.control_name][1])
                        / (currentNodeSummary[tree.trueBranch.maxDiffTreatment][1]
                           + currentNodeSummary[self.control_name][1])
                    )
                else:
                    trueScoreD = 0

                # falseBranch D
                falseNodeSummary = self.tree_node_summary(
                    w_r, y_r, min_samples_treatment=self.min_samples_treatment,
                    n_reg=n_reg, parentNodeSummary=currentNodeSummary
                )
                if (tree.falseBranch.maxDiffTreatment in falseNodeSummary and
                    self.control_name in falseNodeSummary):
                    falseScoreD = (
                        tree.falseBranch.maxDiffSign *
                        (falseNodeSummary[tree.falseBranch.maxDiffTreatment][0]
                         - falseNodeSummary[self.control_name][0])
                    )

                    falseScoreD = (
                        falseScoreD *
                        (falseNodeSummary[tree.falseBranch.maxDiffTreatment][1]
                         + falseNodeSummary[self.control_name][1])
                        / (currentNodeSummary[tree.falseBranch.maxDiffTreatment][1]
                           + currentNodeSummary[self.control_name][1])
                    )
                else:
                    falseScoreD = 0

                if ((trueScoreD + falseScoreD) - currentScoreD <= minGain or
                    (trueScoreD + falseScoreD < 0.)):
                    tree.trueBranch, tree.falseBranch = None, None
                    tree.results = tree.backupResults

            elif rule == 'bestUplift':
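                # 'bestUplift' rule: the split is kept only if the node-size weighted average of the
                # children's best-treatment uplifts exceeds the parent's uplift by more than minGain.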
                # Current D
                if (tree.bestTreatment in currentNodeSummary and
                    self.control_name in currentNodeSummary):
                    currentScoreD = (
                        currentNodeSummary[tree.bestTreatment][0]
                        - currentNodeSummary[self.control_name][0]
                    )
                else:
                    currentScoreD = 0

                # trueBranch D
                trueNodeSummary = self.tree_node_summary(
                    w_l, y_l, min_samples_treatment=self.min_samples_treatment,
                    n_reg=n_reg, parentNodeSummary=currentNodeSummary
                )
                if (tree.trueBranch.bestTreatment in trueNodeSummary and
                    self.control_name in trueNodeSummary):
                    trueScoreD = (
                        trueNodeSummary[tree.trueBranch.bestTreatment][0]
                        - trueNodeSummary[self.control_name][0]
                    )
                else:
                    trueScoreD = 0

                # falseBranch D
                falseNodeSummary = self.tree_node_summary(
                    w_r, y_r, min_samples_treatment=self.min_samples_treatment,
                    n_reg=n_reg, parentNodeSummary=currentNodeSummary
                )
                if (tree.falseBranch.bestTreatment in falseNodeSummary and
                    self.control_name in falseNodeSummary):
                    falseScoreD = (
                        falseNodeSummary[tree.falseBranch.bestTreatment][0]
                        - falseNodeSummary[self.control_name][0]
                    )
                else:
                    falseScoreD = 0
                gain = ((1. * len(y_l) / len(y) * trueScoreD
                         + 1. * len(y_r) / len(y) * falseScoreD)
                        - currentScoreD)
                if gain <= minGain or (trueScoreD + falseScoreD < 0.):
                    tree.trueBranch, tree.falseBranch = None, None
                    tree.results = tree.backupResults
        return self
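
For context, the sketch below shows how pruning is typically driven from user code: grow the tree on a training split, then prune it against a held-out validation split. It is a minimal sketch, assuming UpliftTreeClassifier exposes a public prune() wrapper that forwards to pruneTree(); the synthetic data, group names, and hyperparameters are illustrative only, not taken from the source above.

    # Minimal usage sketch (assumed prune() wrapper; illustrative data and parameters).
    import numpy as np
    from causalml.inference.tree import UpliftTreeClassifier

    rng = np.random.default_rng(42)
    n = 2000
    X = rng.normal(size=(n, 5))                                      # covariates
    treatment = rng.choice(['control', 'treatment_a'], size=n)       # treatment group labels
    y = rng.binomial(1, 0.3 + 0.1 * (treatment == 'treatment_a'))    # outcome with a small lift

    # Grow the tree on one half of the data and prune it against the held-out half.
    X_tr, X_val = X[:n // 2], X[n // 2:]
    w_tr, w_val = treatment[:n // 2], treatment[n // 2:]
    y_tr, y_val = y[:n // 2], y[n // 2:]

    clf = UpliftTreeClassifier(control_name='control', max_depth=5, min_samples_leaf=100)
    clf.fit(X_tr, w_tr, y_tr)

    # prune() recomputes node summaries on the validation data and collapses splits whose
    # gain under the chosen rule ('maxAbsDiff' or 'bestUplift') does not exceed minGain.
    clf.prune(X_val, w_val, y_val, minGain=0.0001, rule='maxAbsDiff')

Using a separate validation split here mirrors the docstring's intent ("Current Node Summary for Validation Data Set"): node summaries are recomputed on data the tree was not grown on, so splits that do not generalize are trimmed.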