def fit()

in causalml/inference/meta/drlearner.py [0:0]


    def fit(self, X, treatment, y, p=None, seed=None):
        """Fit the inference model.

        The rows are split into three folds. In each of three rotations, one fold
        supplies the propensity scores, one fits the outcome regressions, and one
        fits the CATE models on doubly robust pseudo-outcomes:

        - Propensity scores: estimated on that fold, or sliced from ``p`` if given.
        - Outcome regressions: ``mu_c`` for control and ``mu_t`` per treatment group.
        - CATE models: ``tau`` per treatment group.

        One fitted model per fold is kept in ``self.models_mu_c`` (list),
        ``self.models_mu_t`` and ``self.models_tau`` (dicts of group -> list).

        Args:
            X (np.matrix or np.array or pd.Dataframe): a feature matrix
            treatment (np.array or pd.Series): a treatment vector
            y (np.array or pd.Series): an outcome vector
            p (np.ndarray or pd.Series or dict, optional): an array of propensity scores of float (0,1) in the
                single-treatment case; or, a dictionary of treatment groups that map to propensity vectors of
                float (0,1); if None will run ElasticNetPropensityModel() to generate the propensity scores.
                Propensity vectors are aligned with X/treatment/y by row position, not by pandas index label.
            seed (int): random seed for cross-fitting
        """
        X, treatment, y = convert_pd_to_np(X, treatment, y)
        check_treatment_vector(treatment, self.control_name)
        # np.unique returns its result already sorted.
        self.t_groups = np.unique(treatment[treatment != self.control_name])
        self._classes = {group: i for i, group in enumerate(self.t_groups)}

        # The estimator splits the data into 3 partitions for cross-fit on the propensity score estimation,
        # the outcome regression, and the treatment regression on the doubly robust estimates. The use of
        # the partitions is rotated so we do not lose on the sample size.
        n_folds = 3
        cv = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
        split_indices = [index for _, index in cv.split(y)]

        self.models_mu_c = [deepcopy(self.model_mu_c) for _ in range(n_folds)]
        self.models_mu_t = {
            group: [deepcopy(self.model_mu_t) for _ in range(n_folds)]
            for group in self.t_groups
        }
        self.models_tau = {
            group: [deepcopy(self.model_tau) for _ in range(n_folds)]
            for group in self.t_groups
        }

        if p is None:
            # Out-of-fold propensity scores. Every row lands in exactly one tau
            # fold across the rotations, so every entry gets filled below.
            self.propensity = {group: np.zeros(y.shape[0]) for group in self.t_groups}
            p_full = None
        elif isinstance(p, (np.ndarray, pd.Series)):
            # Convert BEFORE slicing. The fold indices are positional (X, treatment
            # and y are numpy by now), but indexing a pd.Series with an integer array
            # is label-based. That would pick misaligned rows, or raise KeyError,
            # whenever the Series index is not a default RangeIndex.
            p_full = {self.t_groups[0]: convert_pd_to_np(p)}
        else:
            p_full = {g: convert_pd_to_np(prop) for g, prop in p.items()}

        for ifold in range(n_folds):
            treatment_idx = split_indices[ifold]
            outcome_idx = split_indices[(ifold + 1) % n_folds]
            tau_idx = split_indices[(ifold + 2) % n_folds]

            treatment_treat = treatment[treatment_idx]
            treatment_out = treatment[outcome_idx]
            treatment_tau = treatment[tau_idx]
            y_out, y_tau = y[outcome_idx], y[tau_idx]
            X_treat, X_out, X_tau = X[treatment_idx], X[outcome_idx], X[tau_idx]

            if p_full is None:
                logger.info("Generating propensity score")
                cur_p = {}
                for group in self.t_groups:
                    # Fit on this group vs. control rows of the propensity fold,
                    # then predict on every row of the tau fold.
                    mask = (treatment_treat == group) | (
                        treatment_treat == self.control_name
                    )
                    w_filt = (treatment_treat[mask] == group).astype(int)
                    w = (treatment_tau == group).astype(int)
                    cur_p[group], _ = compute_propensity_score(
                        X=X_treat[mask], treatment=w_filt, X_pred=X_tau, treatment_pred=w
                    )
                    self.propensity[group][tau_idx] = cur_p[group]
            else:
                cur_p = {g: prop[tau_idx] for g, prop in p_full.items()}
                check_p_conditions(cur_p, self.t_groups)

            logger.info("Generate outcome regressions")
            is_control_out = treatment_out == self.control_name
            self.models_mu_c[ifold].fit(X_out[is_control_out], y_out[is_control_out])
            for group in self.t_groups:
                is_group_out = treatment_out == group
                self.models_mu_t[group][ifold].fit(
                    X_out[is_group_out], y_out[is_group_out]
                )

            logger.info("Fit pseudo outcomes from the DR formula")

            for group in self.t_groups:
                mask = (treatment_tau == group) | (treatment_tau == self.control_name)
                X_filt = X_tau[mask]
                y_filt = y_tau[mask]
                w_filt = (treatment_tau[mask] == group).astype(int)
                p_filt = cur_p[group][mask]
                mu_t = self.models_mu_t[group][ifold].predict(X_filt)
                mu_c = self.models_mu_c[ifold].predict(X_filt)
                # Doubly robust pseudo-outcome:
                #   (W - p) / (p * (1 - p)) * (Y - mu_W(X)) + mu_t(X) - mu_c(X)
                # where mu_W is mu_t for treated rows (W=1) and mu_c for control (W=0).
                dr = (
                    (w_filt - p_filt)
                    / p_filt
                    / (1 - p_filt)
                    * (y_filt - mu_t * w_filt - mu_c * (1 - w_filt))
                    + mu_t
                    - mu_c
                )
                self.models_tau[group][ifold].fit(X_filt, dr)