in tensorflow_decision_forests/keras/wrappers_pre_generated.py [0:0]
def __init__(self,
             task: Optional[TaskType] = core.Task.CLASSIFICATION,
             features: Optional[List[core.FeatureUsage]] = None,
             exclude_non_specified_features: Optional[bool] = False,
             preprocessing: Optional["tf.keras.models.Functional"] = None,
             postprocessing: Optional["tf.keras.models.Functional"] = None,
             ranking_group: Optional[str] = None,
             uplift_treatment: Optional[str] = None,
             temp_directory: Optional[str] = None,
             verbose: int = 1,
             hyperparameter_template: Optional[str] = None,
             advanced_arguments: Optional[AdvancedArguments] = None,
             num_threads: Optional[int] = None,
             name: Optional[str] = None,
             max_vocab_count: Optional[int] = 2000,
             try_resume_training: Optional[bool] = True,
             check_dataset: Optional[bool] = True,
             allow_na_conditions: Optional[bool] = False,
             categorical_algorithm: Optional[str] = "CART",
             categorical_set_split_greedy_sampling: Optional[float] = 0.1,
             categorical_set_split_max_num_items: Optional[int] = -1,
             categorical_set_split_min_item_frequency: Optional[int] = 1,
             growing_strategy: Optional[str] = "LOCAL",
             honest: Optional[bool] = False,
             in_split_min_examples_check: Optional[bool] = True,
             keep_non_leaf_label_distribution: Optional[bool] = True,
             max_depth: Optional[int] = 16,
             max_num_nodes: Optional[int] = None,
             maximum_model_size_in_memory_in_bytes: Optional[float] = -1.0,
             maximum_training_duration_seconds: Optional[float] = -1.0,
             min_examples: Optional[int] = 5,
             missing_value_policy: Optional[str] = "GLOBAL_IMPUTATION",
             num_candidate_attributes: Optional[int] = 0,
             num_candidate_attributes_ratio: Optional[float] = -1.0,
             random_seed: Optional[int] = 123456,
             sorting_strategy: Optional[str] = "PRESORT",
             sparse_oblique_normalization: Optional[str] = None,
             sparse_oblique_num_projections_exponent: Optional[float] = None,
             sparse_oblique_projection_density_factor: Optional[float] = None,
             sparse_oblique_weights: Optional[str] = None,
             split_axis: Optional[str] = "AXIS_ALIGNED",
             uplift_min_examples_in_treatment: Optional[int] = 5,
             uplift_split_score: Optional[str] = "KULLBACK_LEIBLER",
             validation_ratio: Optional[float] = 0.1,
             explicit_args: Optional[Set[str]] = None):
  """Configures this model to train with the "CART" learner.

  The learner-specific keyword arguments (`allow_na_conditions` through
  `validation_ratio`) are packed into one hyperparameter dict. Each entry is
  keyed by its argument name.

  When `hyperparameter_template` is not None, that dict, the template name,
  `self.predefined_hyperparameters()` and `explicit_args` are passed to
  `core._apply_hp_template`, and its result replaces the dict.

  The final dict is forwarded to the parent constructor as `learner_params`,
  together with `learner="CART"` and the generic model arguments.

  `hyperparameter_template` and `explicit_args` are consumed here and are not
  forwarded to the parent constructor.

  NOTE(review): `explicit_args` presumably holds the names of the arguments
  the caller set explicitly, so the template leaves those values alone.
  Confirm against `core._apply_hp_template`.
  """
  # CART hyperparameters, keyed by the same names as the arguments above.
  cart_hyperparameters = dict(
      allow_na_conditions=allow_na_conditions,
      categorical_algorithm=categorical_algorithm,
      categorical_set_split_greedy_sampling=(
          categorical_set_split_greedy_sampling),
      categorical_set_split_max_num_items=categorical_set_split_max_num_items,
      categorical_set_split_min_item_frequency=(
          categorical_set_split_min_item_frequency),
      growing_strategy=growing_strategy,
      honest=honest,
      in_split_min_examples_check=in_split_min_examples_check,
      keep_non_leaf_label_distribution=keep_non_leaf_label_distribution,
      max_depth=max_depth,
      max_num_nodes=max_num_nodes,
      maximum_model_size_in_memory_in_bytes=(
          maximum_model_size_in_memory_in_bytes),
      maximum_training_duration_seconds=maximum_training_duration_seconds,
      min_examples=min_examples,
      missing_value_policy=missing_value_policy,
      num_candidate_attributes=num_candidate_attributes,
      num_candidate_attributes_ratio=num_candidate_attributes_ratio,
      random_seed=random_seed,
      sorting_strategy=sorting_strategy,
      sparse_oblique_normalization=sparse_oblique_normalization,
      sparse_oblique_num_projections_exponent=(
          sparse_oblique_num_projections_exponent),
      sparse_oblique_projection_density_factor=(
          sparse_oblique_projection_density_factor),
      sparse_oblique_weights=sparse_oblique_weights,
      split_axis=split_axis,
      uplift_min_examples_in_treatment=uplift_min_examples_in_treatment,
      uplift_split_score=uplift_split_score,
      validation_ratio=validation_ratio,
  )

  if hyperparameter_template is None:
    resolved_params = cart_hyperparameters
  else:
    # Let the named template adjust the collected hyperparameters.
    resolved_params = core._apply_hp_template(
        cart_hyperparameters,
        hyperparameter_template,
        self.predefined_hyperparameters(),
        explicit_args,
    )

  super(CartModel, self).__init__(
      task=task,
      learner="CART",
      learner_params=resolved_params,
      features=features,
      exclude_non_specified_features=exclude_non_specified_features,
      preprocessing=preprocessing,
      postprocessing=postprocessing,
      ranking_group=ranking_group,
      uplift_treatment=uplift_treatment,
      temp_directory=temp_directory,
      verbose=verbose,
      advanced_arguments=advanced_arguments,
      num_threads=num_threads,
      name=name,
      max_vocab_count=max_vocab_count,
      try_resume_training=try_resume_training,
      check_dataset=check_dataset,
  )