def estimate_ate()

in causalml/inference/meta/tmle.py [0:0]


    def estimate_ate(self, X, treatment, y, p, segment=None, return_ci=False):
        """Estimate the Average Treatment Effect (ATE) with TMLE.

        For each treatment group, only the rows that belong to that group or to the
        control group are used. In the multi-treatment case the effect of a group is
        therefore estimated against the control group, not against "control plus all
        other treatment groups".

        Args:
            X (np.matrix or np.array or pd.Dataframe): a feature matrix
            treatment (np.array or pd.Series): a treatment vector
            y (np.array or pd.Series): an outcome vector
            p (np.ndarray or pd.Series or dict): an array of propensity scores of float (0,1) in the single-treatment
                case; or, a dictionary of treatment groups that map to propensity vectors of float (0,1). Each
                propensity vector must have one entry per row of X (it is indexed with the same row mask as X).
            segment (np.array, optional): An optional 1-d segment vector of int with one entry per row of X. If
                given, the ATE and its CI will be estimated for each segment.
            return_ci (bool, optional): Currently ignored; the lower and upper bounds are always returned.
                Kept for backward compatibility.

        Returns:
            (tuple): The ATE and its confidence interval (LB, UB) for each treatment, t and segment, s, as three
                np.arrays. Without ``segment`` each array has one value per treatment group; with ``segment`` each
                row holds one value per unique segment (NaN for a segment with no usable rows for that group).

        Raises:
            AssertionError: if ``segment`` is not a 1-d array with one entry per row of X.
        """
        X, treatment, y = convert_pd_to_np(X, treatment, y)
        check_treatment_vector(treatment, self.control_name)
        # np.unique returns the values already sorted.
        self.t_groups = np.unique(treatment[treatment != self.control_name])

        check_p_conditions(p, self.t_groups)
        if isinstance(p, (np.ndarray, pd.Series)):
            p = {self.t_groups[0]: convert_pd_to_np(p)}
        elif isinstance(p, dict):
            p = {name: convert_pd_to_np(_p) for name, _p in p.items()}

        # Validate the segment vector once, before any model is fitted.
        segments = None
        if segment is not None:
            segment = np.asarray(segment)
            assert (
                segment.ndim == 1 and segment.shape[0] == X.shape[0]
            ), "Segment must be the 1-d np.array of int."
            # Segments are taken from the full vector so every group reports the same (sorted) set.
            segments = np.unique(segment)

        # Two-sided normal critical value for the (1 - ate_alpha) confidence interval.
        z = norm.ppf(1 - self.ate_alpha / 2)

        ate = []
        ate_lb = []
        ate_ub = []

        for group in self.t_groups:
            logger.info("Estimating ATE for group {}.".format(group))
            # Keep only this treatment group and the control group; otherwise rows of
            # other treatment groups would be treated as controls (w == 0).
            mask = (treatment == group) | (treatment == self.control_name)
            X_group = X[mask]
            y_group = y[mask]
            w_group = (treatment[mask] == group).astype(int)
            p_group = p[group][mask]

            if self.calibrate_propensity:
                logger.info("Calibrating propensity scores.")
                p_group = calibrate(p_group, w_group)

            # Outcome model: the treatment indicator is appended as the last feature, and
            # counterfactual predictions are made with it forced to 0 (control) and 1 (treated).
            if self.cv:
                yhat_c = np.zeros_like(y_group, dtype=float)
                yhat_t = np.zeros_like(y_group, dtype=float)
                for i_fold, (i_trn, i_val) in enumerate(
                    self.cv.split(X_group, y_group), 1
                ):
                    logger.info("Training an outcome model for CV #{}".format(i_fold))
                    self.model_tau.fit(
                        np.hstack((X_group[i_trn], w_group[i_trn].reshape(-1, 1))),
                        y_group[i_trn],
                    )
                    # Out-of-fold predictions only, so each row is predicted by a model
                    # that did not see it during training.
                    yhat_c[i_val] = self.model_tau.predict(
                        np.hstack((X_group[i_val], np.zeros((len(i_val), 1))))
                    )
                    yhat_t[i_val] = self.model_tau.predict(
                        np.hstack((X_group[i_val], np.ones((len(i_val), 1))))
                    )
            else:
                n_group = len(y_group)
                self.model_tau.fit(
                    np.hstack((X_group, w_group.reshape(-1, 1))), y_group
                )
                yhat_c = self.model_tau.predict(
                    np.hstack((X_group, np.zeros((n_group, 1))))
                )
                yhat_t = self.model_tau.predict(
                    np.hstack((X_group, np.ones((n_group, 1))))
                )

            if segments is None:
                logger.info("Training the TMLE learner.")
                _ate, se = simple_tmle(y_group, w_group, yhat_c, yhat_t, p_group)
                _ate_lb = _ate - se * z
                _ate_ub = _ate + se * z
            else:
                segment_group = segment[mask]
                # Rows whose predicted control outcome is at or above the 99th percentile
                # (over this group's rows) are excluded from the per-segment estimates.
                # NOTE(review): this trimming is not applied when segment is None.
                yhat_c_cap = np.quantile(yhat_c, q=0.99)
                keep = yhat_c < yhat_c_cap

                _ate = []
                _ate_lb = []
                _ate_ub = []
                for s in segments:
                    logger.info("Training the TMLE learner for segment {}.".format(s))
                    filt = (segment_group == s) & keep
                    if not filt.any():
                        logger.warning(
                            "No observations for treatment group {} in segment {}; "
                            "returning NaN.".format(group, s)
                        )
                        _ate.append(np.nan)
                        _ate_lb.append(np.nan)
                        _ate_ub.append(np.nan)
                        continue

                    _ate_s, se = simple_tmle(
                        y_group[filt],
                        w_group[filt],
                        yhat_c[filt],
                        yhat_t[filt],
                        p_group[filt],
                    )
                    _ate.append(_ate_s)
                    _ate_lb.append(_ate_s - se * z)
                    _ate_ub.append(_ate_s + se * z)

            ate.append(_ate)
            ate_lb.append(_ate_lb)
            ate_ub.append(_ate_ub)

        return np.array(ate), np.array(ate_lb), np.array(ate_ub)