in econml/_ortho_learner.py [0:0]
def fit(self, Y, T, *, X=None, W=None, Z=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
        cache_values=False, inference=None, only_final=False, check_input=True):
    """
    Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.

    Parameters
    ----------
    Y: (n, d_y) matrix or vector of length n
        Outcomes for each sample
    T: (n, d_t) matrix or vector of length n
        Treatments for each sample
    X: optional (n, d_x) matrix or None (Default=None)
        Features for each sample
    W: optional (n, d_w) matrix or None (Default=None)
        Controls for each sample
    Z: optional (n, d_z) matrix or None (Default=None)
        Instruments for each sample
    sample_weight : (n,) array like, default None
        Individual weights for each sample. If None, it assumes equal weight.
    freq_weight: (n, ) array like of integers, default None
        Weight for the observation. Observation i is treated as the mean
        outcome of freq_weight[i] independent observations.
        When ``sample_var`` is not None, this should be provided.
    sample_var : {(n,), (n, d_y)} nd array like, default None
        Variance of the outcome(s) of the original freq_weight[i] observations that were used to
        compute the mean outcome represented by observation i.
    groups: (n,) vector, optional
        All rows corresponding to the same group will be kept together during splitting.
        If groups is not None, the cv argument passed to this class's initializer
        must support a 'groups' argument to its split method.
    cache_values: bool, default False
        Whether to cache the inputs and computed nuisances, which will allow refitting a different final model
    inference: string, :class:`.Inference` instance, or None
        Method for performing inference.  This estimator supports 'bootstrap'
        (or an instance of :class:`.BootstrapInference`).
    only_final: bool, default False
        Whether to fit the nuisance models or use the existing cached values
        Note. This parameter is only used internally by the `refit` method and should not be exposed
        publicly by overwrites of the `fit` method in public classes.
    check_input: bool, default True
        Whether to check if the input is valid
        Note. This parameter is only used internally by the `refit` method and should not be exposed
        publicly by overwrites of the `fit` method in public classes.

    Returns
    -------
    self : object

    Raises
    ------
    AssertionError
        If exactly one of ``freq_weight`` and ``sample_var`` is provided.
    ValueError
        If ``mc_iters`` is set and ``mc_agg`` is not one of ``'mean'`` or ``'median'``.
    AttributeError
        If the nuisance fits of different Monte Carlo iterations covered different rows.
    """
    # NOTE(review): `inference` is not read anywhere in this body; presumably it is consumed by a
    # decorator applied to this method outside the visible code -- confirm.
    self._random_state = check_random_state(self.random_state)
    assert (freq_weight is None) == (
        sample_var is None), "Sample variances and frequency weights must be provided together!"
    if check_input:
        Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups = check_input_arrays(
            Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)
        self._check_input_dims(Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)

    if not only_final:
        # Validate the Monte Carlo aggregation rule before any (potentially expensive) nuisance
        # models are fit, rather than after all iterations have already run.
        if self.mc_iters is not None and self.mc_agg not in ('mean', 'median'):
            # Braces are doubled so str.format treats them as literal text; a single `{...}` would be
            # parsed as a named replacement field and raise KeyError instead of this ValueError.
            raise ValueError(
                "Parameter `mc_agg` must be one of {{'mean', 'median'}}. Got {}".format(self.mc_agg))

        def _one_hot_encoder(categories):
            # Dense one-hot encoder dropping the first category. scikit-learn 1.2 renamed `sparse`
            # to `sparse_output` (and later removed `sparse`); fall back for older versions that
            # do not accept the new keyword.
            try:
                return OneHotEncoder(categories=categories, sparse_output=False, drop='first')
            except TypeError:
                return OneHotEncoder(categories=categories, sparse=False, drop='first')

        if self.discrete_treatment:
            categories = self.categories
            # OneHotEncoder expects one list of categories per feature column and T is encoded as a
            # single column, so user-supplied categories get wrapped. The isinstance guard avoids an
            # elementwise `!=` comparison when `categories` is an array.
            if not (isinstance(categories, str) and categories == 'auto'):
                categories = [categories]
            self.transformer = _one_hot_encoder(categories)
            self.transformer.fit(reshape(T, (-1, 1)))
            # One encoded treatment column per category, minus the dropped first category.
            self._d_t = (len(self.transformer.categories_[0]) - 1,)
        else:
            self.transformer = None

        if self.discrete_instrument:
            self.z_transformer = _one_hot_encoder('auto')
            self.z_transformer.fit(reshape(Z, (-1, 1)))
        else:
            self.z_transformer = None

        # Nuisance models are weighted by the product of frequency and sample weights
        # (or by whichever of the two is present).
        if freq_weight is None:
            sample_weight_nuisances = sample_weight
        elif sample_weight is None:
            sample_weight_nuisances = freq_weight
        else:
            sample_weight_nuisances = freq_weight * sample_weight

        all_nuisances = []
        fitted_inds = None
        self._models_nuisance = []
        for idx in range(self.mc_iters or 1):
            nuisances, fitted_models, new_inds, scores = self._fit_nuisances(
                Y, T, X, W, Z, sample_weight=sample_weight_nuisances, groups=groups)
            all_nuisances.append(nuisances)
            self._models_nuisance.append(fitted_models)
            if scores is None:
                self.nuisance_scores_ = None
            else:
                if idx == 0:
                    # One list per nuisance score; each collects that score across MC iterations.
                    self.nuisance_scores_ = tuple([] for _ in scores)
                for ind, score in enumerate(scores):
                    self.nuisance_scores_[ind].append(score)
            # Every MC iteration must produce nuisances for the same rows so they can be combined.
            if fitted_inds is None:
                fitted_inds = new_inds
            elif not np.array_equal(fitted_inds, new_inds):
                # NOTE(review): AttributeError is kept for backward compatibility with existing callers.
                raise AttributeError("Different indices were fit by different folds, so they cannot be aggregated")

        if self.mc_iters is not None:
            # mc_agg was validated above, so it is either 'mean' or 'median' here.
            aggregate = np.mean if self.mc_agg == 'mean' else np.median
            nuisances = tuple(aggregate(nuisance_mc_variants, axis=0)
                              for nuisance_mc_variants in zip(*all_nuisances))

        # Restrict every per-row input -- including `groups`, so it stays aligned with the other
        # arrays -- and the nuisances to the row indices returned by `_fit_nuisances`.
        Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups = (
            self._subinds_check_none(arr, fitted_inds)
            for arr in (Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups))
        nuisances = tuple(self._subinds_check_none(nuis, fitted_inds) for nuis in nuisances)

        if cache_values:
            self._cached_values = CachedValues(nuisances=nuisances,
                                               Y=Y, T=T, X=X, W=W, Z=Z,
                                               sample_weight=sample_weight,
                                               freq_weight=freq_weight,
                                               sample_var=sample_var,
                                               groups=groups)
        else:
            self._cached_values = None
    else:
        nuisances = self._cached_values.nuisances
        # _d_t is altered by fit nuisances to what prefit does. So we need to perform the same
        # alteration even when we only want to fit_final.
        if self.transformer is not None:
            self._d_t = (len(self.transformer.categories_[0]) - 1,)

    # The final model receives the one-hot encoded treatment when treatments are discrete.
    T_final = self.transformer.transform(T.reshape((-1, 1))) if self.transformer is not None else T
    self._fit_final(Y=Y,
                    T=T_final,
                    X=X, W=W, Z=Z,
                    nuisances=nuisances,
                    sample_weight=sample_weight,
                    freq_weight=freq_weight,
                    sample_var=sample_var,
                    groups=groups)
    return self