in tensorflow_decision_forests/keras/wrappers_pre_generated.py [0:0]
def __init__(
        self,
        task: Optional[TaskType] = core.Task.CLASSIFICATION,
        features: Optional[List[core.FeatureUsage]] = None,
        exclude_non_specified_features: Optional[bool] = False,
        preprocessing: Optional["tf.keras.models.Functional"] = None,
        postprocessing: Optional["tf.keras.models.Functional"] = None,
        ranking_group: Optional[str] = None,
        uplift_treatment: Optional[str] = None,
        temp_directory: Optional[str] = None,
        verbose: int = 1,
        hyperparameter_template: Optional[str] = None,
        advanced_arguments: Optional[AdvancedArguments] = None,
        num_threads: Optional[int] = None,
        name: Optional[str] = None,
        max_vocab_count: Optional[int] = 2000,
        try_resume_training: Optional[bool] = True,
        check_dataset: Optional[bool] = True,
        adapt_subsample_for_maximum_training_duration: Optional[bool] = False,
        allow_na_conditions: Optional[bool] = False,
        apply_link_function: Optional[bool] = True,
        categorical_algorithm: Optional[str] = "CART",
        categorical_set_split_greedy_sampling: Optional[float] = 0.1,
        categorical_set_split_max_num_items: Optional[int] = -1,
        categorical_set_split_min_item_frequency: Optional[int] = 1,
        compute_permutation_variable_importance: Optional[bool] = False,
        dart_dropout: Optional[float] = 0.01,
        early_stopping: Optional[str] = "LOSS_INCREASE",
        early_stopping_num_trees_look_ahead: Optional[int] = 30,
        focal_loss_alpha: Optional[float] = 0.5,
        focal_loss_gamma: Optional[float] = 2.0,
        forest_extraction: Optional[str] = "MART",
        goss_alpha: Optional[float] = 0.2,
        goss_beta: Optional[float] = 0.1,
        growing_strategy: Optional[str] = "LOCAL",
        honest: Optional[bool] = False,
        in_split_min_examples_check: Optional[bool] = True,
        keep_non_leaf_label_distribution: Optional[bool] = True,
        l1_regularization: Optional[float] = 0.0,
        l2_categorical_regularization: Optional[float] = 1.0,
        l2_regularization: Optional[float] = 0.0,
        lambda_loss: Optional[float] = 1.0,
        loss: Optional[str] = "DEFAULT",
        max_depth: Optional[int] = 6,
        max_num_nodes: Optional[int] = None,
        maximum_model_size_in_memory_in_bytes: Optional[float] = -1.0,
        maximum_training_duration_seconds: Optional[float] = -1.0,
        min_examples: Optional[int] = 5,
        missing_value_policy: Optional[str] = "GLOBAL_IMPUTATION",
        num_candidate_attributes: Optional[int] = -1,
        num_candidate_attributes_ratio: Optional[float] = -1.0,
        num_trees: Optional[int] = 300,
        random_seed: Optional[int] = 123456,
        sampling_method: Optional[str] = "NONE",
        selective_gradient_boosting_ratio: Optional[float] = 0.01,
        shrinkage: Optional[float] = 0.1,
        sorting_strategy: Optional[str] = "PRESORT",
        sparse_oblique_normalization: Optional[str] = None,
        sparse_oblique_num_projections_exponent: Optional[float] = None,
        sparse_oblique_projection_density_factor: Optional[float] = None,
        sparse_oblique_weights: Optional[str] = None,
        split_axis: Optional[str] = "AXIS_ALIGNED",
        subsample: Optional[float] = 1.0,
        uplift_min_examples_in_treatment: Optional[int] = 5,
        uplift_split_score: Optional[str] = "KULLBACK_LEIBLER",
        use_hessian_gain: Optional[bool] = False,
        validation_interval_in_trees: Optional[int] = 1,
        validation_ratio: Optional[float] = 0.1,
        explicit_args: Optional[Set[str]] = None,
):
    """Configures a model trained with the GRADIENT_BOOSTED_TREES learner.

    The arguments fall into two groups:

    * Framework options (``task``, ``features``,
      ``exclude_non_specified_features``, ``preprocessing``,
      ``postprocessing``, ``ranking_group``, ``uplift_treatment``,
      ``temp_directory``, ``verbose``, ``advanced_arguments``,
      ``num_threads``, ``name``, ``max_vocab_count``,
      ``try_resume_training``, ``check_dataset``). These are handed
      unchanged, as keyword arguments, to the parent class constructor.
    * Learner hyper-parameters (``adapt_subsample_for_maximum_training_duration``
      through ``validation_ratio``). These are gathered into a dict whose
      keys are exactly the argument names. The dict is passed to the parent
      constructor as ``learner_params``.

    Args:
      hyperparameter_template: Optional template selector. When not None,
        the hyper-parameter dict is replaced by the result of
        ``core._apply_hp_template(params, hyperparameter_template,
        self.predefined_hyperparameters(), explicit_args)``. Presumably it
        names one of ``self.predefined_hyperparameters()``; NOTE(review):
        confirm against ``core._apply_hp_template``.
      explicit_args: Only used when a template is applied; it is passed
        through to ``core._apply_hp_template``. Presumably it holds the names
        of the arguments the caller set explicitly, so that the template does
        not override them; NOTE(review): confirm with the helper.
      **others**: See the two groups above. Default values are the ones in
        the signature.
    """
    # One entry per learner hyper-parameter, keyed by the argument's own name
    # and kept in the same order as the signature.
    hyper_parameters = dict(
        adapt_subsample_for_maximum_training_duration=(
            adapt_subsample_for_maximum_training_duration),
        allow_na_conditions=allow_na_conditions,
        apply_link_function=apply_link_function,
        categorical_algorithm=categorical_algorithm,
        categorical_set_split_greedy_sampling=(
            categorical_set_split_greedy_sampling),
        categorical_set_split_max_num_items=categorical_set_split_max_num_items,
        categorical_set_split_min_item_frequency=(
            categorical_set_split_min_item_frequency),
        compute_permutation_variable_importance=(
            compute_permutation_variable_importance),
        dart_dropout=dart_dropout,
        early_stopping=early_stopping,
        early_stopping_num_trees_look_ahead=early_stopping_num_trees_look_ahead,
        focal_loss_alpha=focal_loss_alpha,
        focal_loss_gamma=focal_loss_gamma,
        forest_extraction=forest_extraction,
        goss_alpha=goss_alpha,
        goss_beta=goss_beta,
        growing_strategy=growing_strategy,
        honest=honest,
        in_split_min_examples_check=in_split_min_examples_check,
        keep_non_leaf_label_distribution=keep_non_leaf_label_distribution,
        l1_regularization=l1_regularization,
        l2_categorical_regularization=l2_categorical_regularization,
        l2_regularization=l2_regularization,
        lambda_loss=lambda_loss,
        loss=loss,
        max_depth=max_depth,
        max_num_nodes=max_num_nodes,
        maximum_model_size_in_memory_in_bytes=(
            maximum_model_size_in_memory_in_bytes),
        maximum_training_duration_seconds=maximum_training_duration_seconds,
        min_examples=min_examples,
        missing_value_policy=missing_value_policy,
        num_candidate_attributes=num_candidate_attributes,
        num_candidate_attributes_ratio=num_candidate_attributes_ratio,
        num_trees=num_trees,
        random_seed=random_seed,
        sampling_method=sampling_method,
        selective_gradient_boosting_ratio=selective_gradient_boosting_ratio,
        shrinkage=shrinkage,
        sorting_strategy=sorting_strategy,
        sparse_oblique_normalization=sparse_oblique_normalization,
        sparse_oblique_num_projections_exponent=(
            sparse_oblique_num_projections_exponent),
        sparse_oblique_projection_density_factor=(
            sparse_oblique_projection_density_factor),
        sparse_oblique_weights=sparse_oblique_weights,
        split_axis=split_axis,
        subsample=subsample,
        uplift_min_examples_in_treatment=uplift_min_examples_in_treatment,
        uplift_split_score=uplift_split_score,
        use_hessian_gain=use_hessian_gain,
        validation_interval_in_trees=validation_interval_in_trees,
        validation_ratio=validation_ratio,
    )

    # Only consult the predefined templates when one was requested.
    # explicit_args goes along so the helper can tell caller-set values apart.
    if hyperparameter_template is not None:
        hyper_parameters = core._apply_hp_template(
            hyper_parameters,
            hyperparameter_template,
            self.predefined_hyperparameters(),
            explicit_args,
        )

    # The learner identifier is fixed for this wrapper; every framework-level
    # option is forwarded as-is.
    super(GradientBoostedTreesModel, self).__init__(
        task=task,
        learner="GRADIENT_BOOSTED_TREES",
        learner_params=hyper_parameters,
        features=features,
        exclude_non_specified_features=exclude_non_specified_features,
        preprocessing=preprocessing,
        postprocessing=postprocessing,
        ranking_group=ranking_group,
        uplift_treatment=uplift_treatment,
        temp_directory=temp_directory,
        verbose=verbose,
        advanced_arguments=advanced_arguments,
        num_threads=num_threads,
        name=name,
        max_vocab_count=max_vocab_count,
        try_resume_training=try_resume_training,
        check_dataset=check_dataset,
    )