in econml/orf/_ortho_forest.py [0:0]
def __init__(self, *,
             n_trees=500,
             min_leaf_size=10, max_depth=10,
             subsample_ratio=0.7,
             bootstrap=False,
             lambda_reg=0.01,
             model_T='auto',
             model_Y='auto',
             model_T_final=None,
             model_Y_final=None,
             global_residualization=False,
             global_res_cv=2,
             discrete_treatment=False,
             categories='auto',
             n_jobs=-1,
             backend='loky',
             verbose=3,
             batch_size='auto',
             random_state=None):
    """Build the DML nuisance/parameter/moment estimators and initialize the forest.

    Only the arguments this constructor interprets itself are described here.
    ``n_trees``, ``min_leaf_size``, ``max_depth``, ``subsample_ratio``,
    ``bootstrap``, ``n_jobs``, ``backend``, ``verbose``, ``batch_size``,
    ``discrete_treatment`` and ``categories`` are also forwarded unchanged to
    the parent ``__init__``.

    Parameters
    ----------
    lambda_reg : float, default 0.01
        Passed to ``_DMLOrthoForest_second_stage_parameter_estimator_gen`` to
        build the second-stage parameter estimator.
    model_T : estimator or 'auto', default 'auto'
        First-stage treatment model. ``'auto'`` selects
        ``LogisticRegressionCV(cv=3)`` when ``discrete_treatment`` is true,
        otherwise ``WeightedLassoCVWrapper(cv=3)``.
    model_Y : estimator or 'auto', default 'auto'
        First-stage outcome model. ``'auto'`` creates a fresh
        ``WeightedLassoCVWrapper(cv=3)`` for this instance. An explicitly
        passed estimator is used as-is.
    model_T_final, model_Y_final : estimator or None, default None
        Second-stage models. When None, an unfitted ``clone`` of the
        corresponding first-stage model is used.
    global_residualization : bool, default False
        Stored on the instance and forwarded to both nuisance estimator
        generators.
    global_res_cv : int, default 2
        Stored on the instance.
    discrete_treatment : bool, default False
        When true, both treatment models are wrapped in ``_RegressionWrapper``.
    random_state : None, int or RandomState, default None
        Normalized with ``check_random_state``. The resulting object is shared
        by both nuisance estimators and the parent class.
    """
    self.lambda_reg = lambda_reg

    # Resolve 'auto' models. The outcome model used to be a default-argument
    # instance, which is evaluated once and shared by every forest built with
    # the default; building it here gives each instance its own estimator.
    if model_T == 'auto':
        if discrete_treatment:
            model_T = LogisticRegressionCV(cv=3)
        else:
            model_T = WeightedLassoCVWrapper(cv=3)
    if model_Y == 'auto':
        model_Y = WeightedLassoCVWrapper(cv=3)
    self.model_T = model_T
    self.model_Y = model_Y

    # Second-stage models default to unfitted clones of the first-stage ones.
    # The clone is taken before any _RegressionWrapper is applied below, so the
    # final treatment model ends up wrapped exactly once.
    self.model_T_final = model_T_final
    self.model_Y_final = model_Y_final
    if self.model_T_final is None:
        self.model_T_final = clone(self.model_T, safe=False)
    if self.model_Y_final is None:
        self.model_Y_final = clone(self.model_Y, safe=False)
    if discrete_treatment:
        self.model_T = _RegressionWrapper(self.model_T)
        self.model_T_final = _RegressionWrapper(self.model_T_final)

    self.random_state = check_random_state(random_state)
    self.global_residualization = global_residualization
    self.global_res_cv = global_res_cv

    # Define nuisance estimators (first and second stage)
    nuisance_estimator = _DMLOrthoForest_nuisance_estimator_generator(
        self.model_T, self.model_Y, self.random_state, second_stage=False,
        global_residualization=self.global_residualization, discrete_treatment=discrete_treatment)
    second_stage_nuisance_estimator = _DMLOrthoForest_nuisance_estimator_generator(
        self.model_T_final, self.model_Y_final, self.random_state, second_stage=True,
        global_residualization=self.global_residualization, discrete_treatment=discrete_treatment)

    # Define parameter estimators
    parameter_estimator = _DMLOrthoForest_parameter_estimator_func
    second_stage_parameter_estimator = _DMLOrthoForest_second_stage_parameter_estimator_gen(
        self.lambda_reg)

    # Define the moment and mean-gradient estimator
    moment_and_mean_gradient_estimator = _DMLOrthoForest_moment_and_mean_gradient_estimator_func

    super().__init__(
        nuisance_estimator,
        second_stage_nuisance_estimator,
        parameter_estimator,
        second_stage_parameter_estimator,
        moment_and_mean_gradient_estimator,
        n_trees=n_trees,
        min_leaf_size=min_leaf_size,
        max_depth=max_depth,
        subsample_ratio=subsample_ratio,
        bootstrap=bootstrap,
        n_jobs=n_jobs,
        backend=backend,
        verbose=verbose,
        batch_size=batch_size,
        discrete_treatment=discrete_treatment,
        categories=categories,
        random_state=self.random_state)