def _gen_ortho_learner_model_nuisance()

in econml/iv/dr/_dr.py [0:0]


    def _gen_ortho_learner_model_nuisance(self):
        """Assemble the nuisance model handed to the orthogonal learner.

        Every first-stage specification (``model_y_xw``, ``model_t_xw``, ``model_tz_xw`` and, depending on
        ``self.projection``, either ``model_t_xwz`` or ``model_z_xw``) is resolved as follows:

        * the string ``'auto'`` yields a default estimator -- ``LogisticRegressionCV`` with
          ``WeightedStratifiedKFold`` cross-validation when that stage's target is discrete, otherwise
          ``WeightedLassoCVWrapper`` -- seeded with ``self.random_state``;
        * anything else is treated as a user estimator and copied with ``clone(..., safe=False)``.

        Each resolved model is wrapped in a ``_FirstStageWrapper`` that receives its own featurizer from
        ``self._gen_featurizer()``.

        Returns
        -------
        _BaseDRIVModelNuisance
            Built from ``self._gen_prel_model_effect()``, the four wrapped first stages (outcome, treatment,
            treatment-instrument term, and the projection/instrument stage), followed by the
            ``projection``, ``discrete_treatment`` and ``discrete_instrument`` flags.
        """

        def resolve(spec, classifier):
            # 'auto' -> library default (classifier when the target is discrete); otherwise a fresh clone.
            if spec == 'auto':
                if classifier:
                    return LogisticRegressionCV(cv=WeightedStratifiedKFold(random_state=self.random_state),
                                                random_state=self.random_state)
                return WeightedLassoCVWrapper(random_state=self.random_state)
            return clone(spec, safe=False)

        def wrap(model, is_outcome, discrete):
            # Every first stage gets its own featurizer instance; the second flag is True only for the
            # outcome model.
            return _FirstStageWrapper(model, is_outcome, self._gen_featurizer(), False, discrete)

        y_model = resolve(self.model_y_xw, False)
        t_model = resolve(self.model_t_xw, self.discrete_treatment)

        if self.projection:
            # this is a regression model since proj_t is probability, so its target is never discrete
            tz_discrete = False
            tz_model = resolve(self.model_tz_xw, tz_discrete)
            final_discrete = self.discrete_treatment
            final_model = resolve(self.model_t_xwz, final_discrete)
        else:
            tz_discrete = self.discrete_treatment and self.discrete_instrument
            tz_model = resolve(self.model_tz_xw, tz_discrete)
            final_discrete = self.discrete_instrument
            final_model = resolve(self.model_z_xw, final_discrete)

        prel_model_effect = self._gen_prel_model_effect()
        return _BaseDRIVModelNuisance(prel_model_effect,
                                      wrap(y_model, True, False),
                                      wrap(t_model, False, self.discrete_treatment),
                                      wrap(tz_model, False, tz_discrete),
                                      wrap(final_model, False, final_discrete),
                                      self.projection, self.discrete_treatment, self.discrete_instrument)