in misc/CCSynth/CC/DataInsights/src/prose/datainsights/_assertion/_decision_tree_assertions.py [0:0]
def process(self):
    """Learn this node's assertion and, depth permitting, split it into children.

    Steps, in order:

    1. Return immediately if the node is already deeper than
       ``self.max_tree_depth``.
    2. Learn a ``PcaAssertion`` on ``self.df`` and store it in
       ``self.node_assertion``.
    3. Return if nothing was learned (``std_dev_all`` is empty) or if a
       child would sit deeper than ``self.max_tree_depth``.
    4. Collect the candidate split columns, ask ``BestColumnForSplit`` for
       the best one, and build one child ``_Node`` per distinct value of
       that column. Children whose ``number_of_invs`` is 0 are discarded.

    Returns:
        None. Results are written to ``self.node_assertion`` and appended
        to ``self.children``.
    """
    # Don't grow too deep trees
    if self.node_depth > self.max_tree_depth:
        return

    params = self.parameters
    self.node_assertion = PcaAssertion.learn(
        df=self.df,
        max_col_in_slice=params.max_col_in_slice,
        slice_col_overlap=params.slice_col_overlap,
        max_row_in_slice=params.max_row_in_slice,
        use_const_term=params.use_const_term,
        standardize_pca=params.standardize_pca,
        max_self_violation=params.max_self_violation,
        cross_validate=params.cross_validate,
        n_fold=params.n_fold,
        num_invs_to_return=None,
    )

    # Too few data points to learn anything useful.
    # NOTE: keep the explicit len() check -- std_dev_all may be a NumPy
    # array, for which plain truthiness is ambiguous (TODO confirm type).
    if len(self.node_assertion.std_dev_all) == 0:
        return

    # Children would live at node_depth + 1; do not split if that is too deep.
    if self.node_depth + 1 > self.max_tree_depth:
        return

    # Split.
    # A column is a split candidate when:
    #   * partitioning is not restricted to categorical attributes, or the
    #     column's dtype is non-numeric; and
    #   * no cap on distinct values is configured, or the column's number
    #     of distinct values is within that cap.
    # ``Series.items()`` replaces ``Series.iteritems()``, which was removed
    # in pandas 2.0.
    max_unique = self.max_unique_value_per_categorical_attribute
    categorical_columns = [
        column
        for column, dtype in self.df.dtypes.items()
        if (
            not self.partition_on_categorical_attribute_only
            or not np.issubdtype(dtype, np.number)
        )
        and (
            max_unique is None
            or len(self.df[column].unique()) <= max_unique
        )
    ]
    if not categorical_columns:
        return

    best_col = BestColumnForSplit(
        df=self.df,
        cat_columns=categorical_columns,
        child_assertion_increasing_factor=2,
        min_cat_cols_to_check=10,
        max_col_in_slice=params.max_col_in_slice,
        slice_col_overlap=params.slice_col_overlap,
        max_row_in_slice=params.max_row_in_slice,
        use_const_term=params.use_const_term,
        standardize_pca=params.standardize_pca,
        max_self_violation=params.max_self_violation,
        cross_validate=params.cross_validate,
        n_fold=params.n_fold,
        assertion_improvement_factor=params.assertion_improvement_factor,
    ).get_best_column_from_partitions()

    # No column was selected for splitting: this node stays a leaf.
    if best_col is None:
        return

    # One child per distinct value of the chosen column, each constrained
    # by ``best_col == value``.
    for value in self.df[best_col].unique():
        cur_constraint = ConjunctiveConstraint(
            single_constraints=[
                SingleConstraint(
                    column_name=best_col,
                    column_value=value,
                    relational_op=_RelationalOperators.EQUAL,
                )
            ]
        )
        # The full frame is passed on together with the constraint;
        # presumably _Node applies the constraint itself -- TODO confirm.
        child = _Node(
            self.df,
            self.node_depth + 1,
            constraint=cur_constraint,
            parameters=self.parameters,
        )
        # Keep only children that report at least one invariant.
        if child.number_of_invs > 0:
            self.children.append(child)