def fit()

in econml/solutions/causal_analysis/_causal_analysis.py


    def fit(self, X, y, warm_start=False):
        """
        Fits global and local causal effect models for each feature in feature_inds on the data

        Parameters
        ----------
        X : array-like
            Feature data
        y : array-like of shape (n,) or (n,1)
            Outcome. If classification=True, y must take exactly two values; otherwise an error is raised,
            since only binary classification is implemented for now.
            TODO: enable multi-class classification for y (post-MVP)
        warm_start : boolean, default False
            If False, train models for each feature in `feature_inds`.
            If True, train only models for features in `feature_inds` that had not already been trained by
            the previous call to `fit`, and for which neither the corresponding heterogeneity_inds nor the
            automl flag have changed. If heterogeneity_inds have changed, the final stage model for those
            features is refit. If the automl flag has changed, the whole model is refit regardless of the
            warm start flag.
        """

        # Validate inputs
        assert self.nuisance_models in ['automl', 'linear'], (
            "The only supported nuisance models are 'linear' and 'automl', "
            f"but received {self.nuisance_models}")

        assert self.heterogeneity_model in ['linear', 'forest'], (
            "The only supported heterogeneity models are 'linear' and 'forest' but received "
            f"{self.heterogeneity_model}")

        assert np.ndim(X) == 2, f"X must be a 2-dimensional array, but the given X has shape {np.shape(X)}"

        assert iterable(self.feature_inds), f"feature_inds should be array-like, but got {self.feature_inds}"
        assert iterable(self.categorical), f"categorical should be array-like, but got {self.categorical}"
        assert self.heterogeneity_inds is None or iterable(self.heterogeneity_inds), (
            f"heterogeneity_inds should be None or array-like, but got {self.heterogeneity_inds}")
        assert self.feature_names is None or iterable(self.feature_names), (
            f"feature_names should be None or array-like, but got {self.feature_names}")
        assert self.categories == 'auto' or iterable(self.categories), (
            f"categories should be 'auto' or array-like, but got {self.categories}")

        # TODO: check compatibility of X and Y lengths

        if warm_start:
            if not hasattr(self, "_results"):
                # no previous fit, cancel warm start
                warm_start = False

            elif self._d_x != X.shape[1]:
                raise ValueError(
                    f"Can't warm start: previous X had {self._d_x} columns, new X has {X.shape[1]} columns")

        # work with numeric feature indices, so that we can easily compare with categorical ones
        train_inds = _get_column_indices(X, self.feature_inds)

        if len(train_inds) == 0:
            raise ValueError(
                "No features specified. At least one feature index must be specified so that a model can be trained.")

        heterogeneity_inds = self.heterogeneity_inds
        if heterogeneity_inds is None:
            heterogeneity_inds = [None for ind in train_inds]

        # if heterogeneity_inds is 1D, repeat it
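        # e.g. heterogeneity_inds=['age'] is broadcast to [['age'], ['age'], ...], one copy per trained feature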
        if heterogeneity_inds == [] or isinstance(heterogeneity_inds[0], (int, str, bool)):
            heterogeneity_inds = [heterogeneity_inds for _ in train_inds]

        # otherwise, heterogeneity_inds should be a 2D list with the same length as train_inds
        elif len(heterogeneity_inds) != len(train_inds):
            raise ValueError("Heterogeneity indexes should have one entry per feature index, but there "
                             f"were {len(heterogeneity_inds)} heterogeneity entries and "
                             f"{len(train_inds)} feature indices.")

        # replace None elements of heterogeneity_inds and ensure indices are numeric
        heterogeneity_inds = {ind: list(range(X.shape[1])) if hinds is None else _get_column_indices(X, hinds)
                              for ind, hinds in zip(train_inds, heterogeneity_inds)}
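        # result: {feature_ind: [numeric heterogeneity column indices]}, e.g. {0: [0, 1, 2], 2: [1]} (illustrative)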

        if warm_start:
            train_y_model = False
            if self.nuisance_models != self.nuisance_models_:
                warnings.warn("warm_start will be ignored since the nuisance models have changed "
                              f"from {self.nuisance_models_} to {self.nuisance_models} since the previous call to fit")
                warm_start = False
                train_y_model = True

            if self.heterogeneity_model != self.heterogeneity_model_:
                warnings.warn("warm_start will be ignored since the heterogeneity model has changed "
                              f"from {self.heterogeneity_model_} to {self.heterogeneity_model} "
                              "since the previous call to fit")
                warm_start = False

            # TODO: bail out also if categorical columns, classification, random_state changed?
        else:
            train_y_model = True

        # TODO: should we also train a new model_y under any circumstances when warm_start is True?
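        # cache values are (insights, result) pairs keyed by feature index; result.hinds records the
        # heterogeneity columns that the feature's model was fit with (see the unpacking near the end of fit)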
        if warm_start:
            new_inds = [ind for ind in train_inds if (ind not in self._cache or
                                                      heterogeneity_inds[ind] != self._cache[ind][1].hinds)]
        else:
            new_inds = list(train_inds)

            self._cache = {}  # store mapping from feature to insights, results

            # train the Y model
            if train_y_model:
                # perform model selection for the Y model using all X, not on a per-column basis
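                # encode categorical columns (dropping the first level to avoid collinearity)
                # and standardize the remaining columns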
                allX = ColumnTransformer([('encode',
                                           OneHotEncoder(
                                               drop='first', sparse=False),
                                           self.categorical)],
                                         remainder=StandardScaler()).fit_transform(X)

                if self.verbose > 0:
                    print("CausalAnalysis: performing model selection on overall Y model")

                if self.classification:
                    self._model_y = _first_stage_clf(allX, y, automl=self.nuisance_models == 'automl',
                                                     make_regressor=True,
                                                     random_state=self.random_state, verbose=self.verbose)
                else:
                    self._model_y = _first_stage_reg(allX, y, automl=self.nuisance_models == 'automl',
                                                     random_state=self.random_state, verbose=self.verbose)

        if self.classification:
            # now that we've trained the classifier and wrapped it, ensure that y is transformed to
            # work with the regression wrapper

            # we use column_or_1d to treat pd.Series and pd.DataFrame objects the same way as arrays
            y = column_or_1d(y).reshape(-1, 1)

            # note that this needs to happen after wrapping to generalize to the multi-class case,
            # since otherwise we'll have too many columns to be able to train a classifier
            y = OneHotEncoder(drop='first', sparse=False).fit_transform(y)
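            # with two classes, drop='first' leaves a single 0/1 column, satisfying the binary check below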

        assert y.ndim == 1 or y.shape[1] == 1, ("Multiclass classification isn't supported" if self.classification
                                                else "Only a single outcome is supported")

        self._vec_y = y.ndim == 1
        self._d_x = X.shape[1]

        # start with empty results and default shared insights
        self._results = []
        self._shared = _get_default_shared_insights_output()
        self._shared[_CausalInsightsConstants.InitArgsKey] = {
            'feature_inds': _sanitize(self.feature_inds),
            'categorical': _sanitize(self.categorical),
            'heterogeneity_inds': _sanitize(self.heterogeneity_inds),
            'feature_names': _sanitize(self.feature_names),
            'classification': _sanitize(self.classification),
            'upper_bound_on_cat_expansion': _sanitize(self.upper_bound_on_cat_expansion),
            'nuisance_models': _sanitize(self.nuisance_models),
            'heterogeneity_model': _sanitize(self.heterogeneity_model),
            'categories': _sanitize(self.categories),
            'n_jobs': _sanitize(self.n_jobs),
            'verbose': _sanitize(self.verbose),
            'random_state': _sanitize(self.random_state)
        }

        # convert categorical indicators to numeric indices
        categorical_inds = _get_column_indices(X, self.categorical)

        categories = self.categories
        if categories == 'auto':
            categories = ['auto' for _ in categorical_inds]
        else:
            assert len(categories) == len(categorical_inds), (
                "If categories is not 'auto', it must contain one entry per categorical column.  Instead, categories"
                f"has length {len(categories)} while there are {len(categorical_inds)} categorical columns.")

        # check for indices over the categorical expansion bound
        invalid_inds = getattr(self, 'untrained_feature_indices_', [])

        # assume we'll be able to train former failures this time; we'll add them back if not
        invalid_inds = [(ind, reason) for (ind, reason) in invalid_inds if ind not in new_inds]

        self._has_column_names = True
        if self.feature_names is None:
            if hasattr(X, "iloc"):
                feature_names = X.columns
            else:
                self._has_column_names = False
                feature_names = [f"x{i}" for i in range(X.shape[1])]
        else:
            feature_names = self.feature_names
        self.feature_names_ = feature_names

        min_counts = {}
        for ind in new_inds:
            column_text = self._format_col(ind)

            if ind in categorical_inds:
                cats, counts = np.unique(_safe_indexing(X, ind, axis=1), return_counts=True)
                min_ind = np.argmin(counts)
                n_cat = len(cats)
                if n_cat > self.upper_bound_on_cat_expansion:
                    warnings.warn(f"{column_text} has more than {self.upper_bound_on_cat_expansion} "
                                  f"values (found {n_cat}) so no heterogeneity model will be fit for it; "
                                  "increase 'upper_bound_on_cat_expansion' to change this behavior.")
                    # can't remove in place while iterating over new_inds, so store in separate list
                    invalid_inds.append((ind, 'upper_bound_on_cat_expansion'))

                elif counts[min_ind] < _CAT_LIMIT:
                    if self.skip_cat_limit_checks and (counts[min_ind] >= 5 or
                                                       (counts[min_ind] >= 2 and
                                                        self.heterogeneity_model != 'forest')):
                        # train the model, but warn
                        warnings.warn(f"{column_text}'s value {cats[min_ind]} has only {counts[min_ind]} instances in "
                                      f"the training dataset, which is less than the lower limit ({_CAT_LIMIT}). "
                                      "A model will still be fit because 'skip_cat_limit_checks' is True, "
                                      "but this model may not be robust.")
                        min_counts[ind] = counts[min_ind]
                    elif counts[min_ind] < 2 or (counts[min_ind] < 5 and self.heterogeneity_model == 'forest'):
                        # no model can be trained in this case since we need more folds
                        warnings.warn(f"{column_text}'s value {cats[min_ind]} has only {counts[min_ind]} instances in "
                                      "the training dataset, but linear heterogeneity models need at least 2 and "
                                      "forest heterogeneity models need at least 5 instances, so no model will be fit "
                                      "for this column")
                        invalid_inds.append((ind, 'cat_limit'))
                    else:
                        # don't train a model, but suggest a workaround, since there are enough
                        # instances of the least populated class
                        warnings.warn(f"{column_text}'s value {cats[min_ind]} has only {counts[min_ind]} instances in "
                                      f"the training dataset, which is less than the lower limit ({_CAT_LIMIT}), "
                                      "so no heterogeneity model will be fit for it. This check can be turned off by "
                                      "setting 'skip_cat_limit_checks' to True, but that may result in an inaccurate "
                                      "model for this feature.")
                        invalid_inds.append((ind, 'cat_limit'))

        for (ind, _) in invalid_inds:
            new_inds.remove(ind)
            # also remove from train_inds so we don't try to access the result later
            train_inds.remove(ind)
            if len(train_inds) == 0:
                raise ValueError("No features remain; increase the upper_bound_on_cat_expansion and ensure that there "
                                 "are several instances of each categorical value so that at least "
                                 "one feature model can be trained.")

        # extract subset of names matching new columns
        new_feat_names = _safe_indexing(feature_names, new_inds)

        cache_updates = dict(zip(new_inds,
                                 joblib.Parallel(
                                     n_jobs=self.n_jobs,
                                     verbose=self.verbose
                                 )(joblib.delayed(_process_feature)(
                                     feat_name, feat_ind,
                                     self.verbose, categorical_inds, categories, heterogeneity_inds, min_counts, y, X,
                                     self.nuisance_models, self.heterogeneity_model, self.random_state, self._model_y,
                                     self.cv, self.mc_iters)
                                     for feat_name, feat_ind in zip(new_feat_names, new_inds))))
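        # each cache_updates value is either an (insights, result) pair or the Exception raised while
        # fitting that feature; failed entries are removed below rather than cached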

        # track indices where an exception was thrown, since we can't remove from dictionary while iterating
        inds_to_remove = []
        for ind, value in cache_updates.items():
            if isinstance(value, Exception):
                # don't want to cache this failed result
                inds_to_remove.append(ind)
                train_inds.remove(ind)
                invalid_inds.append((ind, value))

        for ind in inds_to_remove:
            del cache_updates[ind]

        self._cache.update(cache_updates)

        for ind in train_inds:
            dict_update, result = self._cache[ind]
            self._results.append(result)
            for k in dict_update:
                self._shared[k] += dict_update[k]

        invalid_inds.sort()
        self.untrained_feature_indices_ = invalid_inds
        self.trained_feature_indices_ = train_inds

        self.nuisance_models_ = self.nuisance_models
        self.heterogeneity_model_ = self.heterogeneity_model
        return self
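
Example usage (a minimal sketch: the data, column names, and parameter values below are
illustrative assumptions, not prescribed by the method):

    import numpy as np
    import pandas as pd
    from econml.solutions.causal_analysis import CausalAnalysis

    rng = np.random.default_rng(0)
    X = pd.DataFrame({
        "age": rng.integers(18, 70, size=500),
        "income": rng.normal(50_000, 10_000, size=500),
        "region": rng.choice(["N", "S", "E", "W"], size=500),  # categorical column
    })
    y = rng.binomial(1, 0.3, size=500)  # binary outcome, so classification=True is valid

    ca = CausalAnalysis(
        feature_inds=["age", "income", "region"],
        categorical=["region"],
        classification=True,
        nuisance_models="linear",        # or 'automl'
        heterogeneity_model="linear",    # or 'forest'
    )
    ca.fit(X, y)

    # a second call with warm_start=True retrains only features whose settings have changed
    ca.fit(X, y, warm_start=True)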