def fit()

in econml/_ortho_learner.py


    def fit(self, Y, T, *, X=None, W=None, Z=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
            cache_values=False, inference=None, only_final=False, check_input=True):
        """
        Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.

        Parameters
        ----------
        Y: (n, d_y) matrix or vector of length n
            Outcomes for each sample
        T: (n, d_t) matrix or vector of length n
            Treatments for each sample
        X: (n, d_x) matrix, optional, default None
            Features for each sample
        W: (n, d_w) matrix, optional, default None
            Controls for each sample
        Z: (n, d_z) matrix, optional, default None
            Instruments for each sample
        sample_weight: (n,) array_like, default None
            Individual weights for each sample. If None, it assumes equal weight.
        freq_weight: (n,) array_like of integers, default None
            Weight for the observation. Observation i is treated as the mean
            outcome of freq_weight[i] independent observations.
            When ``sample_var`` is not None, this should be provided.
        sample_var: {(n,), (n, d_y)} array_like, default None
            Variance of the outcome(s) of the original freq_weight[i] observations that were used to
            compute the mean outcome represented by observation i.
        groups: (n,) vector, optional
            All rows corresponding to the same group will be kept together during splitting.
            If groups is not None, the cv argument passed to this class's initializer
            must support a 'groups' argument to its split method.
        cache_values: bool, default False
            Whether to cache the inputs and computed nuisances, which will allow refitting a different final model
        inference: string, :class:`.Inference` instance, or None
            Method for performing inference.  This estimator supports 'bootstrap'
            (or an instance of :class:`.BootstrapInference`).
        only_final: bool, default False
            Whether to skip fitting the nuisance models and instead reuse the existing cached values.
            Note. This parameter is only used internally by the `refit` method and should not be exposed
            publicly by overrides of the `fit` method in public classes.
        check_input: bool, default True
            Whether to check that the input is valid.
            Note. This parameter is only used internally by the `refit` method and should not be exposed
            publicly by overrides of the `fit` method in public classes.

        Returns
        -------
        self : object
        """
        self._random_state = check_random_state(self.random_state)
        assert (freq_weight is None) == (
            sample_var is None), "Sample variances and frequency weights must be provided together!"
        if check_input:
            Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups = check_input_arrays(
                Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)
            self._check_input_dims(Y, T, X, W, Z, sample_weight, freq_weight, sample_var, groups)

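        # Unless only_final is set (the refit path), fit the nuisance models first; otherwise reuse
        # the cached nuisance values computed by a previous call with cache_values=True.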
        if not only_final:

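            # One-hot encode a discrete treatment (dropping the first category), so the effective
            # treatment dimension becomes the number of categories minus one.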
            if self.discrete_treatment:
                categories = self.categories
                if categories != 'auto':
                    categories = [categories]  # OneHotEncoder expects a 2D array with features per column
                self.transformer = OneHotEncoder(categories=categories, sparse=False, drop='first')
                self.transformer.fit(reshape(T, (-1, 1)))
                self._d_t = (len(self.transformer.categories_[0]) - 1,)
            else:
                self.transformer = None

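            # Likewise one-hot encode the instrument when it is discrete.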
            if self.discrete_instrument:
                self.z_transformer = OneHotEncoder(categories='auto', sparse=False, drop='first')
                self.z_transformer.fit(reshape(Z, (-1, 1)))
            else:
                self.z_transformer = None

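            # Combine frequency weights and sample weights into the effective weights used when
            # fitting the nuisance models.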
            all_nuisances = []
            fitted_inds = None
            if sample_weight is None:
                if freq_weight is not None:
                    sample_weight_nuisances = freq_weight
                else:
                    sample_weight_nuisances = None
            else:
                if freq_weight is not None:
                    sample_weight_nuisances = freq_weight * sample_weight
                else:
                    sample_weight_nuisances = sample_weight

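            # Cross-fit the nuisance models mc_iters times (once when mc_iters is None), collecting the
            # out-of-fold nuisance estimates, fitted models, and scores from each monte carlo iteration.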
            self._models_nuisance = []
            for idx in range(self.mc_iters or 1):
                nuisances, fitted_models, new_inds, scores = self._fit_nuisances(
                    Y, T, X, W, Z, sample_weight=sample_weight_nuisances, groups=groups)
                all_nuisances.append(nuisances)
                self._models_nuisance.append(fitted_models)
                if scores is None:
                    self.nuisance_scores_ = None
                else:
                    if idx == 0:
                        self.nuisance_scores_ = tuple([] for _ in scores)
                    for ind, score in enumerate(scores):
                        self.nuisance_scores_[ind].append(score)
                if fitted_inds is None:
                    fitted_inds = new_inds
                elif not np.array_equal(fitted_inds, new_inds):
                    raise AttributeError("Different indices were fit by different folds, so they cannot be aggregated")

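            # Aggregate the nuisance estimates across monte carlo iterations using the mean or median.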
            if self.mc_iters is not None:
                if self.mc_agg == 'mean':
                    nuisances = tuple(np.mean(nuisance_mc_variants, axis=0)
                                      for nuisance_mc_variants in zip(*all_nuisances))
                elif self.mc_agg == 'median':
                    nuisances = tuple(np.median(nuisance_mc_variants, axis=0)
                                      for nuisance_mc_variants in zip(*all_nuisances))
                else:
                    raise ValueError(
                        "Parameter `mc_agg` must be one of {'mean', 'median'}. Got {}".format(self.mc_agg))

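            # Restrict the data arrays to the rows that were actually fit by the folds and, if
            # cache_values is True, cache everything needed to later refit a different final model.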
            Y, T, X, W, Z, sample_weight, freq_weight, sample_var = (self._subinds_check_none(arr, fitted_inds)
                                                                     for arr in (Y, T, X, W, Z, sample_weight,
                                                                                 freq_weight, sample_var))
            nuisances = tuple([self._subinds_check_none(nuis, fitted_inds) for nuis in nuisances])
            self._cached_values = CachedValues(nuisances=nuisances,
                                               Y=Y, T=T, X=X, W=W, Z=Z,
                                               sample_weight=sample_weight,
                                               freq_weight=freq_weight,
                                               sample_var=sample_var,
                                               groups=groups) if cache_values else None
        else:
            nuisances = self._cached_values.nuisances
            # _d_t was altered during nuisance fitting relative to what _prefit set, so we need to
            # perform the same alteration even when only fitting the final model.
            if self.transformer is not None:
                self._d_t = (len(self.transformer.categories_[0]) - 1,)

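        # Fit the final (second-stage) model on the computed nuisances, one-hot encoding the
        # treatment first when it is discrete.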
        self._fit_final(Y=Y,
                        T=self.transformer.transform(T.reshape((-1, 1))) if self.transformer is not None else T,
                        X=X, W=W, Z=Z,
                        nuisances=nuisances,
                        sample_weight=sample_weight,
                        freq_weight=freq_weight,
                        sample_var=sample_var,
                        groups=groups)

        return self
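
Since `_OrthoLearner` is an abstract base class, this `fit` is normally reached through a concrete public estimator. A minimal sketch (assuming econml is installed, using `LinearDML` and made-up data) of how the signature above is typically exercised:

import numpy as np
from econml.dml import LinearDML

# Hypothetical toy data; shapes follow the docstring above.
n = 1000
X = np.random.normal(size=(n, 3))
W = np.random.normal(size=(n, 2))
T = np.random.binomial(1, 0.5, size=n)
Y = T * X[:, 0] + W[:, 0] + np.random.normal(size=n)

est = LinearDML(discrete_treatment=True, cv=3, mc_iters=2, mc_agg='median')
est.fit(Y, T, X=X, W=W, cache_values=True, inference='bootstrap')
print(est.effect(X[:5]))   # CATE estimates for the first five rows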