in tensorflow_decision_forests/component/py_tree/condition.py [0:0]
def _make_categorical_contains_condition(attribute, items,
                                         na_value) -> AbstractCondition:
  """Builds the python "contains" condition matching the attribute type.

  Args:
    attribute: Simple column spec of the tested attribute.
    items: Items (values or integer indices) tested by the condition.
    na_value: Value of the condition when the attribute is missing.

  Returns:
    A CategoricalIsInCondition for CATEGORICAL attributes, or a
    CategoricalSetContainsCondition for CATEGORICAL_SET attributes.

  Raises:
    ValueError: If the attribute is neither CATEGORICAL nor CATEGORICAL_SET.
  """
  if attribute.type == ColumnType.CATEGORICAL:
    return CategoricalIsInCondition(attribute, items, na_value)
  if attribute.type == ColumnType.CATEGORICAL_SET:
    return CategoricalSetContainsCondition(attribute, items, na_value)
  # Previously this case silently fell through to the generic "Non supported
  # condition type" error, which hid the real cause.
  raise ValueError(
      f"Non supported column type {attribute.type} for a contains "
      f"condition on attribute {attribute}")


def core_condition_to_condition(
    core_condition: decision_tree_pb2.NodeCondition,
    dataspec: data_spec_pb2.DataSpecification) -> AbstractCondition:
  """Converts a condition from the core to python format.

  Args:
    core_condition: Node condition in the core (proto) format.
    dataspec: Dataspec of the model. Used to resolve the attribute index(es)
      referenced by the condition into column specs, and to map categorical
      indices and discretized thresholds back to their values.

  Returns:
    The python condition equivalent to `core_condition`.

  Raises:
    ValueError: If the condition type is not supported, or if a "contains"
      condition targets a column that is neither CATEGORICAL nor
      CATEGORICAL_SET.
  """
  condition_type = core_condition.condition
  attribute = dataspec_lib.make_simple_column_spec(dataspec,
                                                   core_condition.attribute)
  column_spec = dataspec.columns[core_condition.attribute]
  na_value = core_condition.na_value

  if condition_type.HasField("na_condition"):
    return IsMissingInCondition(attribute)

  if condition_type.HasField("higher_condition"):
    return NumericalHigherThanCondition(
        attribute, condition_type.higher_condition.threshold, na_value)

  if condition_type.HasField("true_value_condition"):
    return IsTrueCondition(attribute, na_value)

  if condition_type.HasField("contains_bitmap_condition"):
    items = column_spec_bitmap_to_items(
        column_spec, condition_type.contains_bitmap_condition.elements_bitmap)
    return _make_categorical_contains_condition(attribute, items, na_value)

  if condition_type.HasField("contains_condition"):
    elements = condition_type.contains_condition.elements
    if column_spec.categorical.is_already_integerized:
      # Indices are used as-is. Copy them out of the proto repeated field so
      # the python condition does not hold a reference to the proto.
      items = list(elements)
    else:
      # Map each dictionary index back to its categorical value.
      items = [
          dataspec_lib.categorical_value_idx_to_value(column_spec, item)
          for item in elements
      ]
    return _make_categorical_contains_condition(attribute, items, na_value)

  if condition_type.HasField("discretized_higher_condition"):
    # The threshold is stored in the discretized domain; convert it back to a
    # numerical threshold using the column's dataspec.
    threshold = dataspec_lib.discretized_numerical_to_numerical(
        column_spec, condition_type.discretized_higher_condition.threshold)
    return NumericalHigherThanCondition(attribute, threshold, na_value)

  if condition_type.HasField("oblique_condition"):
    oblique = condition_type.oblique_condition
    attributes = [
        dataspec_lib.make_simple_column_spec(dataspec, attribute_idx)
        for attribute_idx in oblique.attributes
    ]
    return NumericalSparseObliqueCondition(attributes, list(oblique.weights),
                                           oblique.threshold, na_value)

  raise ValueError(f"Non supported condition type: {core_condition}")