in causalml/inference/meta/xlearner.py [0:0]
def fit(self, X, treatment, y, p=None):
    """Fit the inference model.

    For every non-control treatment group, the rows belonging to that group
    or to the control group are selected and four models are trained on them:
    an outcome model on the control rows, an outcome model on the treated
    rows, and one effect model per arm fitted to the imputed treatment
    effects derived from the opposite arm's outcome model.

    Args:
        X (np.matrix or np.array or pd.Dataframe): a feature matrix
        treatment (np.array or pd.Series): a treatment vector
        y (np.array or pd.Series): an outcome vector
        p (np.ndarray or pd.Series or dict, optional): an array of propensity scores of float (0,1) in the
            single-treatment case; or, a dictionary of treatment groups that map to propensity vectors of
            float (0,1); if None, self._set_propensity_models() is called to generate the propensity
            scores.

    Sets the following attributes on self: t_groups, _classes, models_mu_c,
    models_mu_t, models_tau_c, models_tau_t, vars_c and vars_t (the last six
    are dictionaries keyed by treatment group).
    """
    X, treatment, y = convert_pd_to_np(X, treatment, y)
    check_treatment_vector(treatment, self.control_name)

    # np.unique already returns its values in sorted order, so no separate
    # sort of the treatment groups is required.
    self.t_groups = np.unique(treatment[treatment != self.control_name])

    if p is None:
        # NOTE(review): presumably stores the fitted scores on self.propensity;
        # the helper is defined outside this block - confirm.
        self._set_propensity_models(X=X, treatment=treatment, y=y)
    else:
        # The formatted value is not needed for fitting; the call is kept so
        # that any checks performed inside _format_p still run on the input.
        self._format_p(p, self.t_groups)

    self._classes = {group: i for i, group in enumerate(self.t_groups)}

    # Every treatment group gets its own independent copy of each base learner.
    self.models_mu_c = {group: deepcopy(self.model_mu_c) for group in self.t_groups}
    self.models_mu_t = {group: deepcopy(self.model_mu_t) for group in self.t_groups}
    self.models_tau_c = {
        group: deepcopy(self.model_tau_c) for group in self.t_groups
    }
    self.models_tau_t = {
        group: deepcopy(self.model_tau_t) for group in self.t_groups
    }
    self.vars_c = {}
    self.vars_t = {}

    # The control-row mask does not depend on the group, so build it once.
    control_rows = treatment == self.control_name

    for group in self.t_groups:
        mask = (treatment == group) | control_rows
        treatment_filt = treatment[mask]
        X_filt = X[mask]
        y_filt = y[mask]
        w = (treatment_filt == group).astype(int)

        # Split into control / treated partitions once; every model below
        # reuses these instead of re-slicing (and re-copying) the arrays.
        in_control = w == 0
        in_treated = w == 1
        X_c, y_c = X_filt[in_control], y_filt[in_control]
        X_t, y_t = X_filt[in_treated], y_filt[in_treated]

        mu_c = self.models_mu_c[group]
        mu_t = self.models_mu_t[group]

        # Train outcome models
        mu_c.fit(X_c, y_c)
        mu_t.fit(X_t, y_t)

        # Calculate variances: residual variance of each outcome model on the
        # rows it was trained on.
        # NOTE(review): [:, 1] takes the second predict_proba column,
        # presumably P(y = 1) for a binary outcome - confirm class ordering.
        self.vars_c[group] = (y_c - mu_c.predict_proba(X_c)[:, 1]).var()
        self.vars_t[group] = (y_t - mu_t.predict_proba(X_t)[:, 1]).var()

        # Imputed treatment effects: each arm's observed outcome is compared
        # with the prediction of the *other* arm's outcome model.
        d_c = mu_t.predict_proba(X_c)[:, 1] - y_c
        d_t = y_t - mu_c.predict_proba(X_t)[:, 1]

        # Train treatment models
        self.models_tau_c[group].fit(X_c, d_c)
        self.models_tau_t[group].fit(X_t, d_t)