def _split_exposure_hajek()

in causalPartition.py [0:0]


    def _split_exposure_hajek(self, node_id, df, probabilities, feature_set, max_attempt, eps, delta, 
                              outcome, rules, N, current_mse, criteria={'non_trivial_reduction': 0},
                             first_split_treatment=True):
        """
        the actual splitting implementation for separate tree; 
        by recursion
        """
        b_feature = ''
        b_threshold = 0
        b_left = None
        b_right = None
        b_average_left_hajek = 0
        b_average_right_hajek = 0
        b_mse = 10000000000.0  # a very large mse

        ranges = {}
        # enumerate each feature
        for feature in feature_set:
            gc.collect()
            # find a more compact region
            upper = 1.
            lower = 0.
            for rule in rules:
                # rules: list of tuples to describe the decision rules
                # tuples(feature, 0/1: lower or upper bound, value)
                if rule[0] == feature:
                    if rule[1] == 0:
                        lower = np.maximum(rule[2], lower)
                    else:
                        upper = np.minimum(rule[2], upper)
            if lower >= upper:
                continue
            
            for k in range(max_attempt):
                if first_split_treatment and node_id == 1:
                    if feature != self.treatment or k != 0:
                        continue
                
                threshold = np.random.uniform(lower, upper)  # randomly select a threshold, left < , right >
                # make sure it is a valid split --- each observation should have non-trial (>eps) probability to belong to each partition
                cz_l = self._contain_zero(probabilities, rules+[(feature, 0, threshold)], eps, delta)
                cz_r = self._contain_zero(probabilities, rules+[(feature, 1, threshold)], eps, delta)
                if np.mean(cz_l) > delta or np.mean(cz_r) > delta:
                    continue
                    # if (almost) positivity can't be satisfied
                    
                idxs_left = np.product([df[key] <= th for key, sign, th in rules if sign == 0] + \
                       [df[key] > th for key, sign, th in rules if sign == 1] + \
                        [df[feature] <= threshold],
                       axis=0) > 0

                idxs_right = np.product([df[key] <= th for key, sign, th in rules if sign == 0] + \
                       [df[key] > th for key, sign, th in rules if sign == 1] + \
                        [df[feature] > threshold],
                       axis=0) > 0

                left = df[idxs_left]
                right = df[idxs_right]
                
                # generalized propensity score (probability of belonging in an exposure condition)
                propensities_left = np.mean(np.product([probabilities[key][idxs_left] <= th for key, sign, th in rules if sign == 0] + \
                           [probabilities[key][idxs_left] > th for key, sign, th in rules if sign == 1] + \
                            [probabilities[feature][idxs_left] <= threshold],
                           axis=0) > 0, axis=1)

                # generalized propensity score (probability of belonging in an exposure condition)
                propensities_right = np.mean(np.product([probabilities[key][idxs_right] <= th for key, sign, th in rules if sign == 0] + \
                           [probabilities[key][idxs_right] > th for key, sign, th in rules if sign == 1] + \
                            [probabilities[feature][idxs_right] > threshold],
                           axis=0) > 0, axis=1)
                # again, filter small propensities data points (usually should not filter or filter very few)
                
                if len(left) == 0 or len(right) == 0:
                    continue
                
                filter_left = propensities_left > 0
                left = left[filter_left]
                propensities_left = propensities_left[filter_left]
                
                filter_right = propensities_right > 0
                right = right[filter_right]
                propensities_right = propensities_right[filter_right]
                
                mod_left = sm.WLS(left[outcome], np.ones(len(left)), weights=1.0 / propensities_left)
                mod_right = sm.WLS(right[outcome], np.ones(len(right)), weights=1.0 / propensities_right)
                
                res_left = mod_left.fit()
                res_right = mod_right.fit()
                
                average_left_hajek = res_left.params[0] 
                average_right_hajek = res_right.params[0]

                average_left_hajek_se = self._hajek_se(left, propensities_left, outcome)
                average_right_hajek_se = self._hajek_se(right, propensities_right, outcome)

                mse_left = np.sum((1.0 / propensities_left) * ((res_left.resid) ** 2))
                mse_right = np.sum((1.0 / propensities_right) * ((res_right.resid) ** 2))
                mse = mse_left * len(left)/(len(left)+len(right)) + mse_right * len(right)/(len(left)+len(right))
                
                if mse < b_mse:
                    flag = True
                    assert len(criteria) > 0
                    if 'non_trivial_reduction' in criteria:
                        if not (mse < current_mse - criteria['non_trivial_reduction']):
                            flag = False
                    if 'reasonable_propensity' in criteria:
                        if not (np.abs(np.sum(1.0 / propensities_left)/len(df) - 1.0) <= criteria['reasonable_propensity'] \
                                and \
                                np.abs(np.sum(1.0 / propensities_right)/len(df) - 1.0) <= criteria['reasonable_propensity'] \
                               ):
                            flag = False
                    if 'separate_reduction' in criteria:
                        if not (mse_left < current_mse and mse_right < current_mse):
                            flag = False
                    if 'min_leaf_size' in criteria:
                        if not (len(left) >= criteria['min_leaf_size'] and len(right) >= criteria['min_leaf_size']):
                            flag = False
                    if flag:
                        b_feature = feature
                        b_mse = mse
                        b_mse_left = mse_left
                        b_mse_right = mse_right
                        b_threshold = threshold
                        b_average_left_hajek = average_left_hajek
                        b_average_right_hajek = average_right_hajek
                        b_average_left_hajek_se = average_left_hajek_se
                        b_average_right_hajek_se = average_right_hajek_se
                        b_left_den = np.sum(1.0 / propensities_left)
                        b_right_den = np.sum(1.0 / propensities_right)
                        b_left = left
                        b_right = right
                        b_left_rules = rules + [(feature, 0, threshold)]
                        b_right_rules = rules + [(feature, 1, threshold)]

        result = {}
        if b_feature != '':
            # if find a valid partition
            result_left = self._split_exposure_hajek(node_id*2, df, probabilities, feature_set, max_attempt, eps, delta, 
                                                     outcome, b_left_rules, len(b_left), b_mse_left, criteria)
            result_right = self._split_exposure_hajek(node_id*2+1, df, probabilities, feature_set, max_attempt, eps, delta, 
                                                      outcome, b_right_rules, len(b_right), b_mse_right, criteria)
            result['mse'] = result_left['mse'] * 1.0 * len(b_left)/(len(b_left)+len(b_right)) + \
                        result_right['mse'] * 1.0 * len(b_right)/(len(b_left)+len(b_right))
            result['feature'] = b_feature
            result['threshold'] = b_threshold
            result_left['hajek'] = b_average_left_hajek
            result_right['hajek'] = b_average_right_hajek
            result_left['hajek_se'] = b_average_left_hajek_se
            result_right['hajek_se'] = b_average_right_hajek_se
            result_left['N'] = len(b_left)
            result_right['N'] = len(b_right)
            result_left['den'] = b_left_den
            result_right['den'] = b_right_den
            result['left_result'] = result_left
            result['right_result'] = result_right
            return result
        else:
            result['mse'] = current_mse
            return result