def __init__()

in econml/orf/_ortho_forest.py [0:0]


    def __init__(self, *,
                 n_trees=500,
                 min_leaf_size=10, max_depth=10,
                 subsample_ratio=0.7,
                 bootstrap=False,
                 lambda_reg=0.01,
                 model_T='auto',
                 model_Y=WeightedLassoCVWrapper(cv=3),
                 model_T_final=None,
                 model_Y_final=None,
                 global_residualization=False,
                 global_res_cv=2,
                 discrete_treatment=False,
                 categories='auto',
                 n_jobs=-1,
                 backend='loky',
                 verbose=3,
                 batch_size='auto',
                 random_state=None):
        # Copy and/or define models
        self.lambda_reg = lambda_reg
        if model_T == 'auto':
            if discrete_treatment:
                model_T = LogisticRegressionCV(cv=3)
            else:
                model_T = WeightedLassoCVWrapper(cv=3)
        self.model_T = model_T
        self.model_Y = model_Y
        self.model_T_final = model_T_final
        self.model_Y_final = model_Y_final
        if self.model_T_final is None:
            self.model_T_final = clone(self.model_T, safe=False)
        if self.model_Y_final is None:
            self.model_Y_final = clone(self.model_Y, safe=False)
        if discrete_treatment:
            self.model_T = _RegressionWrapper(self.model_T)
            self.model_T_final = _RegressionWrapper(self.model_T_final)
        self.random_state = check_random_state(random_state)
        self.global_residualization = global_residualization
        self.global_res_cv = global_res_cv
        # Define nuisance estimators
        nuisance_estimator = _DMLOrthoForest_nuisance_estimator_generator(
            self.model_T, self.model_Y, self.random_state, second_stage=False,
            global_residualization=self.global_residualization, discrete_treatment=discrete_treatment)
        second_stage_nuisance_estimator = _DMLOrthoForest_nuisance_estimator_generator(
            self.model_T_final, self.model_Y_final, self.random_state, second_stage=True,
            global_residualization=self.global_residualization, discrete_treatment=discrete_treatment)
        # Define parameter estimators
        parameter_estimator = _DMLOrthoForest_parameter_estimator_func
        second_stage_parameter_estimator = _DMLOrthoForest_second_stage_parameter_estimator_gen(
            self.lambda_reg)
        # Define
        moment_and_mean_gradient_estimator = _DMLOrthoForest_moment_and_mean_gradient_estimator_func
        super().__init__(
            nuisance_estimator,
            second_stage_nuisance_estimator,
            parameter_estimator,
            second_stage_parameter_estimator,
            moment_and_mean_gradient_estimator,
            n_trees=n_trees,
            min_leaf_size=min_leaf_size,
            max_depth=max_depth,
            subsample_ratio=subsample_ratio,
            bootstrap=bootstrap,
            n_jobs=n_jobs,
            backend=backend,
            verbose=verbose,
            batch_size=batch_size,
            discrete_treatment=discrete_treatment,
            categories=categories,
            random_state=self.random_state)