def predict()

in causalml/inference/iv/drivlearner.py [0:0]


    def predict(self, X, treatment=None, y=None, return_components=False, verbose=True):
        """Predict treatment effects.

        Args:
            X (np.matrix or np.array or pd.Dataframe): a feature matrix
            treatment (np.array or pd.Series, optional): a treatment vector
            y (np.array or pd.Series, optional): an outcome vector
            return_components (bool, optional): whether to also return the averaged outcome
                predictions of the control (models_mu_c) and treatment (models_mu_t) models
            verbose (bool, optional): whether to log regression error metrics for each treatment
                group; only takes effect when both treatment and y are provided
        Returns:
            (numpy.ndarray): Predictions of treatment effects for compliers, i.e. those individuals
                who take the treatment only if they are assigned. The array has one row per row
                of X and one column per entry of self.t_groups.
            If return_components is True, a tuple (te, yhat_cs, yhat_ts) is returned instead, where
                yhat_cs and yhat_ts are dicts keyed by treatment group that hold the averaged
                predictions of the control and treatment outcome models respectively.
        """
        X, treatment, y = convert_pd_to_np(X, treatment, y)

        te = np.zeros((X.shape[0], self.t_groups.shape[0]))
        yhat_cs = {}
        yhat_ts = {}

        # Error metrics compare predictions against observed outcomes, so they
        # need both y and treatment in addition to the verbose flag.
        log_metrics = (y is not None) and (treatment is not None) and bool(verbose)
        # The outcome-model predictions are only consumed by the metrics below or
        # handed back as components; skip that inference when neither applies.
        need_components = return_components or log_metrics

        def _ensemble_mean(models):
            # np.r_ applied to a single list builds an array with one leading
            # entry per model; averaging over axis 0 combines the ensemble.
            return np.r_[[model.predict(X) for model in models]].mean(axis=0)

        for i, group in enumerate(self.t_groups):
            te[:, i] = np.ravel(_ensemble_mean(self.models_tau[group]))

            if need_components:
                yhat_cs[group] = _ensemble_mean(self.models_mu_c[group])
                yhat_ts[group] = _ensemble_mean(self.models_mu_t[group])

            if log_metrics:
                # Keep only the rows belonging to this treatment group or to control.
                mask = (treatment == group) | (treatment == self.control_name)
                treatment_filt = treatment[mask]
                y_filt = y[mask]
                # 1 for rows in this treatment group, 0 for control rows.
                w = (treatment_filt == group).astype(int)

                # Use the control-model prediction for control rows and the
                # treatment-model prediction for treated rows.
                yhat = np.zeros_like(y_filt, dtype=float)
                yhat[w == 0] = yhat_cs[group][mask][w == 0]
                yhat[w == 1] = yhat_ts[group][mask][w == 1]

                logger.info("Error metrics for group {}".format(group))
                regression_metrics(y_filt, yhat, w)

        if return_components:
            return te, yhat_cs, yhat_ts
        return te