def growDecisionTreeFrom()

in causalml/inference/tree/uplift.pyx [0:0]


    def growDecisionTreeFrom(self, X, treatment_idx, y, X_val, treatment_val_idx, y_val,
                             early_stopping_eval_diff_scale=1, max_depth=10,
                             min_samples_leaf=100, depth=1,
                             min_samples_treatment=10, n_reg=100,
                             parentNodeSummary_p=None):
        '''
        Train the uplift decision tree by recursively searching for the best split.

        At each node the method summarizes the outcome rate and sample count of
        every treatment group, scores candidate (column, value) splits with the
        configured split criterion, and recurses into both children of the best
        split while its gain is positive and ``depth < max_depth``. Otherwise
        the node is returned as a leaf.

        Args
        ----
        X : ndarray, shape = [num_samples, num_features]
            An ndarray of the covariates used to train the uplift model.
        treatment_idx : array-like, shape = [num_samples]
            An array containing the treatment group idx for each unit.
            Its dtype must equal TR_TYPE (asserted); documented as numpy.int8.
        y : array-like, shape = [num_samples]
            An array containing the outcome of interest for each unit.
            Its dtype must equal Y_TYPE (asserted).
        X_val : ndarray, shape = [num_val_samples, num_features], or None
            An ndarray of the covariates used to validate the uplift model.
            When None, validation-based early stopping of splits is skipped.
        treatment_val_idx : array-like, shape = [num_val_samples]
            An array containing the validation treatment group idx for each unit.
            Only used when X_val is not None.
        y_val : array-like, shape = [num_val_samples]
            An array containing the validation outcome of interest for each unit.
            Only used when X_val is not None.
        early_stopping_eval_diff_scale : float, optional (default=1)
            Only used when X_val is not None. A candidate split is rejected if,
            for any group k in either child, the absolute difference between the
            validation and training P(Y=1) summaries exceeds
            min(validation, training) / early_stopping_eval_diff_scale.
        max_depth: int, optional (default=10)
            The maximum depth of the tree.
        min_samples_leaf: int, optional (default=100)
            Minimum number of training samples in each child. Candidate splits
            leaving fewer samples in either child are skipped.
        depth : int, optional (default = 1)
            The current depth.
        min_samples_treatment: int, optional (default=10)
            Minimum per-group sample count in each child; candidate splits in
            which any group of either child has fewer samples are skipped.
            Also forwarded to tree_node_summary_from_counts.
        n_reg: int, optional (default=100)
            The regularization parameter defined in Rzepakowski et al. 2012,
            the weight (in terms of sample size) of the parent node influence
            on the child node, only effective for 'KL', 'ED', 'Chi', 'CTS' methods.
        parentNodeSummary_p : array-like, shape [n_class], or None
            Node summary probability statistics of the parent tree node.
            None at the root, in which case a zero-filled placeholder is used
            together with has_parent_summary = 0.

        Returns
        -------
        object of DecisionTree class
            An internal node (col, value, trueBranch, falseBranch) when a split
            with positive gain was found and depth < max_depth; otherwise a
            leaf carrying ``results``. If X has no rows, a DecisionTree with
            only ``classes_`` set is returned.
        '''

        # Empty node: nothing to summarize or split.
        if len(X) == 0:
            return DecisionTree(classes_=self.classes_)

        assert treatment_idx.dtype == TR_TYPE
        assert y.dtype == Y_TYPE

        # some temporary buffers for node summaries
        cdef int n_class = self.n_class
        # buffers for group counts (2 entries per group); right can be derived from total and left
        cdef np.ndarray[N_TYPE_t, ndim=1] left_count_arr = np.zeros(2 * self.n_class, dtype = N_TYPE)
        cdef np.ndarray[N_TYPE_t, ndim=1] right_count_arr = np.zeros(2 * self.n_class, dtype = N_TYPE)
        cdef np.ndarray[N_TYPE_t, ndim=1] total_count_arr = np.zeros(2 * self.n_class, dtype = N_TYPE)
        # same buffers for X_val, allocated below only when X_val is given
        cdef np.ndarray[N_TYPE_t, ndim=1] val_left_count_arr
        cdef np.ndarray[N_TYPE_t, ndim=1] val_right_count_arr
        cdef np.ndarray[N_TYPE_t, ndim=1] val_total_count_arr
        # buffers for node summaries: *_p = per-group probability, *_n = per-group count
        cdef np.ndarray[P_TYPE_t, ndim=1] cur_summary_p = np.zeros(self.n_class, dtype = P_TYPE)
        cdef np.ndarray[N_TYPE_t, ndim=1] cur_summary_n = np.zeros(self.n_class, dtype = N_TYPE)
        cdef np.ndarray[P_TYPE_t, ndim=1] left_summary_p = np.zeros(self.n_class, dtype = P_TYPE)
        cdef np.ndarray[N_TYPE_t, ndim=1] left_summary_n = np.zeros(self.n_class, dtype = N_TYPE)
        cdef np.ndarray[P_TYPE_t, ndim=1] right_summary_p = np.zeros(self.n_class, dtype = P_TYPE)
        cdef np.ndarray[N_TYPE_t, ndim=1] right_summary_n = np.zeros(self.n_class, dtype = N_TYPE)
        # for val left and right summary
        cdef np.ndarray[P_TYPE_t, ndim=1] val_left_summary_p = np.zeros(self.n_class, dtype = P_TYPE)
        cdef np.ndarray[N_TYPE_t, ndim=1] val_left_summary_n = np.zeros(self.n_class, dtype = N_TYPE)
        cdef np.ndarray[P_TYPE_t, ndim=1] val_right_summary_p = np.zeros(self.n_class, dtype = P_TYPE)
        cdef np.ndarray[N_TYPE_t, ndim=1] val_right_summary_n = np.zeros(self.n_class, dtype = N_TYPE)
        
        # Parent summary: at the root a zero-filled placeholder is passed so that
        # tree_node_summary_from_counts always receives an array; the
        # has_parent_summary flag records whether it holds real parent statistics.
        cdef int has_parent_summary = 0
        if parentNodeSummary_p is None:
            parent_summary_p = np.zeros(self.n_class, dtype = P_TYPE) # dummy for calling tree_node_summary_to_arr
            has_parent_summary = 0
        else:
            parent_summary_p = parentNodeSummary_p
            has_parent_summary = 1

        cdef int i = 0

        # preparation: fill in the total count, then for each
        # candidate split, we calculate the count for left branch, and
        # can derive count for right branch using the total count.

        # total_count_arr layout: [N(Y=0, T=0), N(Y=1, T=0), N(Y=0, T=1), N(Y=1, T=1), ...]
        group_uniqueCounts_to_arr(treatment_idx, y, total_count_arr)
        if X_val is not None:
            val_left_count_arr = np.zeros(2 * self.n_class, dtype = N_TYPE)
            val_right_count_arr = np.zeros(2 * self.n_class, dtype = N_TYPE)
            val_total_count_arr = np.zeros(2 * self.n_class, dtype = N_TYPE)
            group_uniqueCounts_to_arr(treatment_val_idx, y_val, val_total_count_arr)

        # Current node summary: [P(Y=1|T=i)...] and [N(T=i)...]
        self.tree_node_summary_from_counts(
            total_count_arr,
            cur_summary_p, cur_summary_n,
            parent_summary_p,
            has_parent_summary,
            min_samples_treatment=min_samples_treatment,
            n_reg=n_reg
            )

        # to reconstruct current node summary in list of list form, so
        # that the constructed tree follows previous format.

        # Current node summary: [[P(Y=1|T=i), N(T=i)]...]
        currentNodeSummary = []
        for i in range(n_class):
            currentNodeSummary.append([cur_summary_p[i], cur_summary_n[i]])

        # IT and CIT score a split from the children (and parent) directly,
        # so no parent-node score is needed; other criteria score this node.
        if self.evaluationFunction == self.evaluate_IT or self.evaluationFunction == self.evaluate_CIT:
            currentScore = 0
        else:
            currentScore = self.arr_eval_func(cur_summary_p, cur_summary_n)

        # Prune Stats: compare every treatment group i_tr >= 1 against group 0 (control).
        cdef P_TYPE_t maxAbsDiff = 0.0
        cdef P_TYPE_t maxDiff = -1.
        cdef int bestTreatment = 0       # last treatment with the largest positive diff; stays 0 (control) if none is positive; also used in returning the tree for this node
        cdef int suboptTreatment = 0     # last treatment with the largest diff, positive or not
        cdef int maxDiffTreatment = 0    # last treatment with the largest |diff|; also used in returning the tree for this node
        maxDiffSign = 0 # sign of the diff for maxDiffTreatment; also used in returning the tree for this node
        # adapted to new current node summary format
        cdef P_TYPE_t p_c = cur_summary_p[0]
        cdef N_TYPE_t n_c = cur_summary_n[0]
        cdef N_TYPE_t n_t = 0
        cdef int i_tr = 0
        cdef P_TYPE_t p_t = 0.0, diff = 0.0

        for i_tr in range(1, n_class):
            p_t = cur_summary_p[i_tr]
            # P(Y=1|T=t) - P(Y=1|T=0)
            diff = p_t - p_c
            # ">=" means ties are resolved in favor of the later treatment index.
            if fabs(diff) >= maxAbsDiff:
                maxDiffTreatment = i_tr
                maxDiffSign = np.sign(diff)
                maxAbsDiff = fabs(diff)
            if diff >= maxDiff:
                maxDiff = diff
                suboptTreatment = i_tr
                if diff > 0:
                    bestTreatment = i_tr
        if maxDiff > 0:
            p_t = cur_summary_p[bestTreatment]
            n_t = cur_summary_n[bestTreatment]
        else:
            p_t = cur_summary_p[suboptTreatment]
            n_t = cur_summary_n[suboptTreatment]
        # Two-sided z-test p-value for the difference between the selected
        # treatment rate p_t and the control rate p_c.
        # NOTE(review): no guard for n_t == 0, n_c == 0, or a zero standard
        # error (both rates 0 or 1) — confirm callers rule these out.
        p_value = (1. - stats.norm.cdf(fabs(p_c - p_t) / sqrt(p_t * (1 - p_t) / n_t + p_c * (1 - p_c) / n_c))) * 2
        upliftScore = [maxDiff, p_value]

        bestGain = 0.0
        bestGainImp = 0.0
        bestAttribute = None
        # keep mostly scalar when finding best split, then get the structural value after finding the best split
        best_col = None
        best_value = None
        len_X = len(X)
        len_X_val = len(X_val) if X_val is not None else 0

        # Candidate thresholds for numeric columns: these percentiles of the
        # column when it has > 10 unique values, else of its unique values.
        c_num_percentiles = [3, 5, 10, 20, 30, 50, 70, 80, 90, 95, 97]
        c_cat_percentiles = [10, 50, 90]

        # X holds only covariates; search a random subset of max_features
        # columns (all columns when self.max_features is unset or out of range).
        columnCount = X.shape[1]
        if (self.max_features and self.max_features > 0 and self.max_features <= columnCount):
            max_features = self.max_features
        else:
            max_features = columnCount

        for col in list(self.random_state_.choice(a=range(columnCount), size=max_features, replace=False)):
            columnValues = X[:, col]
            # unique values
            lsUnique = np.unique(columnValues)

            if np.issubdtype(lsUnique.dtype, np.number):
                # numeric column: threshold split (exact comparison is done in group_counts_by_divide)
                is_split_by_gt = True
                if len(lsUnique) > 10:
                    lspercentile = np.percentile(columnValues, c_num_percentiles)
                else:
                    lspercentile = np.percentile(lsUnique, c_cat_percentiles)
                lsUnique = np.unique(lspercentile)
            else:
                # to split by equality check.
                is_split_by_gt = False

            for value in lsUnique:
                # Fills left_count_arr for this candidate and returns the left-branch size.
                len_X_l = group_counts_by_divide(columnValues, value, is_split_by_gt, treatment_idx, y, left_count_arr)
                len_X_r = len_X - len_X_l

                # check the split validity on min_samples_leaf
                if (len_X_l < min_samples_leaf or len_X_r < min_samples_leaf):
                    continue
                # summarize notes
                # p: fraction of this node's samples that go to the left branch
                p = float(len_X_l) / len_X

                # right branch group counts can be calculated from left branch counts and total counts
                for i in range(2 * n_class):
                    right_count_arr[i] = total_count_arr[i] - left_count_arr[i]

                # left and right node summary, into the temporary buffers {left,right}_summary_{p,n}
                self.tree_node_summary_from_counts(
                    left_count_arr,
                    left_summary_p, left_summary_n,
                    cur_summary_p,
                    1,
                    min_samples_treatment,
                    n_reg
                    )

                self.tree_node_summary_from_counts(
                    right_count_arr,
                    right_summary_p, right_summary_n,
                    cur_summary_p,
                    1,
                    min_samples_treatment,
                    n_reg
                    )

                if X_val is not None:
                    # NOTE(review): len_X_val_l is not used afterwards.
                    len_X_val_l = group_counts_by_divide(X_val[:, col], value, is_split_by_gt, treatment_val_idx, y_val, val_left_count_arr)

                    # right branch group counts can be calculated from left branch counts and total counts
                    for i in range(2 * n_class):
                        val_right_count_arr[i] = val_total_count_arr[i] - val_left_count_arr[i]

                    # NOTE(review): unlike the training summaries above, these calls
                    # do not pass min_samples_treatment / n_reg, so the helper's
                    # defaults apply — confirm this is intended.
                    self.tree_node_summary_from_counts(
                        val_left_count_arr,
                        val_left_summary_p, val_left_summary_n,
                        cur_summary_p, # parentNodeSummary_p
                        1 # has_parent_summary
                    )

                    self.tree_node_summary_from_counts(
                        val_right_count_arr,
                        val_right_summary_p, val_right_summary_n,
                        cur_summary_p, # parentNodeSummary_p
                        1 # has_parent_summary
                    )

                    # Early stopping: reject this candidate if, for any group, the
                    # validation and training child probabilities disagree by more
                    # than min(val, train) / early_stopping_eval_diff_scale.
                    early_stopping_flag = False
                    for k in range(n_class):
                        if (abs(val_left_summary_p[k] - left_summary_p[k]) >
                                min(val_left_summary_p[k], left_summary_p[k])/early_stopping_eval_diff_scale or
                            abs(val_right_summary_p[k] - right_summary_p[k]) > 
                                min(val_right_summary_p[k], right_summary_p[k])/early_stopping_eval_diff_scale):
                            early_stopping_flag = True
                            break

                    if early_stopping_flag:
                        continue

                # check the split validity on min_samples_treatment (smallest group count in either child)
                node_mst = min(np.min(left_summary_n), np.min(right_summary_n))
                if node_mst < min_samples_treatment:
                    continue

                # evaluate the split
                # gain ranks candidate splits; gain_for_imp is a sample-count-weighted
                # variant that is added to feature_imp_dict for the chosen split.
                if self.arr_eval_func == self.arr_evaluate_CTS:
                    # CTS: reduction of the node score by the weighted child scores
                    leftScore1 = self.arr_eval_func(left_summary_p, left_summary_n)
                    rightScore2 = self.arr_eval_func(right_summary_p, right_summary_n)
                    gain = (currentScore - p * leftScore1 - (1 - p) * rightScore2)
                    gain_for_imp = (len_X * currentScore - len_X_l * leftScore1 - len_X_r * rightScore2)
                elif self.arr_eval_func == self.arr_evaluate_DDP:
                    # DDP: absolute difference between the child scores
                    leftScore1 = self.arr_eval_func(left_summary_p, left_summary_n)
                    rightScore2 = self.arr_eval_func(right_summary_p, right_summary_n)
                    gain = np.abs(leftScore1 - rightScore2)
                    gain_for_imp = np.abs(len_X_l * leftScore1 - len_X_r * rightScore2)
                elif self.arr_eval_func == self.arr_evaluate_IT:
                    # IT: criterion computed from both children directly
                    gain = self.arr_eval_func(left_summary_p, left_summary_n, right_summary_p, right_summary_n)
                    gain_for_imp = gain * len_X
                elif self.arr_eval_func == self.arr_evaluate_CIT:
                    # CIT: criterion computed from the parent and both children
                    gain = self.arr_eval_func(cur_summary_p, cur_summary_n,
                                              left_summary_p, left_summary_n,
                                              right_summary_p, right_summary_n)
                    gain_for_imp = gain * len_X
                elif self.arr_eval_func == self.arr_evaluate_IDDP:
                    # IDDP: |left - right| minus |current score|, optionally normalized
                    leftScore1 = self.arr_eval_func(left_summary_p, left_summary_n)
                    rightScore2 = self.arr_eval_func(right_summary_p, right_summary_n)
                    gain = np.abs(leftScore1 - rightScore2) - np.abs(currentScore)
                    gain_for_imp = (len_X_l * leftScore1 + len_X_r * rightScore2 - len_X * np.abs(currentScore))
                    if self.normalization:
                        # Normalize used divergence
                        currentDivergence = 2 * (gain + 1) / 3
                        norm_factor = self.arr_normI(cur_summary_n, left_summary_n, alpha=0.9, currentDivergence=currentDivergence)
                    else:
                        norm_factor = 1
                    gain = gain / norm_factor
                else:
                    # Remaining criteria: weighted child scores minus the node score, optionally normalized
                    leftScore1 = self.arr_eval_func(left_summary_p, left_summary_n)
                    rightScore2 = self.arr_eval_func(right_summary_p, right_summary_n)
                    gain = (p * leftScore1 + (1 - p) * rightScore2 - currentScore)
                    gain_for_imp = (len_X_l * leftScore1 + len_X_r * rightScore2 - len_X * currentScore)
                    if self.normalization:
                        norm_factor = self.arr_normI(cur_summary_n, left_summary_n, alpha=0.9)
                    else:
                        norm_factor = 1
                    gain = gain / norm_factor 
                # NOTE(review): this uses strict ">" while the earlier
                # min_samples_leaf check skips only "<", so a split whose child
                # has exactly min_samples_leaf samples is scored but never chosen.
                if (gain > bestGain and len_X_l > min_samples_leaf and len_X_r > min_samples_leaf):
                    bestGain = gain
                    bestGainImp = gain_for_imp
                    best_col = col
                    best_value = value
        
        # after finding the best split col and value
        if best_col is not None:
            bestAttribute = (best_col, best_value)
            # re-calculate the divideSet: materialize the row subsets for the chosen split
            X_l, X_r, w_l, w_r, y_l, y_r = self.divideSet(X, treatment_idx, y, best_col, best_value)
            if X_val is not None:
                X_val_l, X_val_r, w_val_l, w_val_r, y_val_l, y_val_r = self.divideSet(X_val, treatment_val_idx, y_val, best_col, best_value)
                best_set_left = [X_l, w_l, y_l, X_val_l, w_val_l, y_val_l]
                best_set_right = [X_r, w_r, y_r, X_val_r, w_val_r, y_val_r]
            else:
                best_set_left = [X_l, w_l, y_l, None, None, None]
                best_set_right = [X_r, w_r, y_r, None, None, None]

        # Display summary attached to the node.
        dcY = {'impurity': '%.3f' % currentScore, 'samples': '%d' % len(X)}
        # Add treatment size
        dcY['group_size'] = ''
        for i, summary in enumerate(currentNodeSummary):
            dcY['group_size'] += ' ' + self.classes_[i] + ': ' + str(summary[1])
        dcY['upliftScore'] = [round(upliftScore[0], 4), round(upliftScore[1], 4)]
        dcY['matchScore'] = round(upliftScore[0], 4)

        # bestGain > 0 implies a split was found (bestGain only increases when best_col is set).
        if bestGain > 0 and depth < max_depth:
            self.feature_imp_dict[bestAttribute[0]] += bestGainImp
            # NOTE(review): the recursion forwards self.early_stopping_eval_diff_scale,
            # not this call's early_stopping_eval_diff_scale argument.
            trueBranch = self.growDecisionTreeFrom(
                *best_set_left, self.early_stopping_eval_diff_scale, max_depth, min_samples_leaf,
                depth + 1, min_samples_treatment=min_samples_treatment,
                n_reg=n_reg, parentNodeSummary_p=cur_summary_p
            )
            falseBranch = self.growDecisionTreeFrom(
                *best_set_right, self.early_stopping_eval_diff_scale, max_depth, min_samples_leaf,
                depth + 1, min_samples_treatment=min_samples_treatment,
                n_reg=n_reg, parentNodeSummary_p=cur_summary_p
            )

            return DecisionTree(
                classes_=self.classes_,
                col=bestAttribute[0], value=bestAttribute[1],
                trueBranch=trueBranch, falseBranch=falseBranch, summary=dcY,
                maxDiffTreatment=maxDiffTreatment, maxDiffSign=maxDiffSign,
                nodeSummary=currentNodeSummary,
                backupResults=self.uplift_classification_results(treatment_idx, y),
                bestTreatment=bestTreatment, upliftScore=upliftScore
            )
        else:
            # Leaf node. CTS leaves omit maxDiffTreatment / maxDiffSign.
            if self.evaluationFunction == self.evaluate_CTS:
                return DecisionTree(
                    classes_=self.classes_,
                    results=self.uplift_classification_results(treatment_idx, y),
                    summary=dcY, nodeSummary=currentNodeSummary,
                    bestTreatment=bestTreatment, upliftScore=upliftScore
                )
            else:
                return DecisionTree(
                    classes_=self.classes_,
                    results=self.uplift_classification_results(treatment_idx, y),
                    summary=dcY, maxDiffTreatment=maxDiffTreatment,
                    maxDiffSign=maxDiffSign, nodeSummary=currentNodeSummary,
                    bestTreatment=bestTreatment, upliftScore=upliftScore
                )