def core_condition_to_condition()

in tensorflow_decision_forests/component/py_tree/condition.py [0:0]


def _make_contains_condition(attribute, items, na_value) -> AbstractCondition:
  """Builds the python "contains" condition matching the attribute's type.

  Args:
    attribute: Simple column spec of the tested attribute.
    items: Items tested by the condition.
    na_value: `na_value` of the core condition, forwarded unchanged.

  Returns:
    A `CategoricalIsInCondition` for CATEGORICAL attributes, or a
    `CategoricalSetContainsCondition` for CATEGORICAL_SET attributes.

  Raises:
    ValueError: The attribute is neither CATEGORICAL nor CATEGORICAL_SET.
  """
  if attribute.type == ColumnType.CATEGORICAL:
    return CategoricalIsInCondition(attribute, items, na_value)
  if attribute.type == ColumnType.CATEGORICAL_SET:
    return CategoricalSetContainsCondition(attribute, items, na_value)
  # Previously this case silently fell through to the generic "Non supported
  # condition type" error; report the actual cause instead.
  raise ValueError(
      "Contains conditions are only supported on CATEGORICAL and "
      f"CATEGORICAL_SET columns. Got attribute: {attribute}")


def core_condition_to_condition(
    core_condition: decision_tree_pb2.NodeCondition,
    dataspec: data_spec_pb2.DataSpecification) -> AbstractCondition:
  """Converts a condition from the core to python format.

  The concrete python condition is selected from whichever of the following
  fields is set on `core_condition.condition`: `na_condition`,
  `higher_condition`, `true_value_condition`, `contains_bitmap_condition`,
  `contains_condition`, `discretized_higher_condition` or
  `oblique_condition`.

  Args:
    core_condition: Core (proto) condition of a decision tree node.
    dataspec: Dataspec that `core_condition.attribute` (and, for oblique
      conditions, `oblique_condition.attributes`) index into.

  Returns:
    The equivalent python condition object.

  Raises:
    ValueError: The condition type is not supported, or a contains condition
      targets a column that is neither CATEGORICAL nor CATEGORICAL_SET.
  """

  condition_type = core_condition.condition
  na_value = core_condition.na_value
  attribute = dataspec_lib.make_simple_column_spec(dataspec,
                                                   core_condition.attribute)
  column_spec = dataspec.columns[core_condition.attribute]

  if condition_type.HasField("na_condition"):
    return IsMissingInCondition(attribute)

  if condition_type.HasField("higher_condition"):
    return NumericalHigherThanCondition(
        attribute, condition_type.higher_condition.threshold, na_value)

  if condition_type.HasField("true_value_condition"):
    return IsTrueCondition(attribute, na_value)

  if condition_type.HasField("contains_bitmap_condition"):
    items = column_spec_bitmap_to_items(
        column_spec, condition_type.contains_bitmap_condition.elements_bitmap)
    return _make_contains_condition(attribute, items, na_value)

  if condition_type.HasField("contains_condition"):
    elements = condition_type.contains_condition.elements
    if column_spec.categorical.is_already_integerized:
      # Copy the repeated proto field so the returned condition does not keep
      # a reference to the proto's mutable storage.
      items = list(elements)
    else:
      # Map the stored indices back to the dictionary values of the column.
      items = [
          dataspec_lib.categorical_value_idx_to_value(column_spec, item)
          for item in elements
      ]
    return _make_contains_condition(attribute, items, na_value)

  if condition_type.HasField("discretized_higher_condition"):
    # The stored threshold is a discretized bucket index; convert it back to
    # a numerical threshold.
    threshold = dataspec_lib.discretized_numerical_to_numerical(
        column_spec, condition_type.discretized_higher_condition.threshold)
    return NumericalHigherThanCondition(attribute, threshold, na_value)

  if condition_type.HasField("oblique_condition"):
    oblique = condition_type.oblique_condition
    attributes = [
        dataspec_lib.make_simple_column_spec(dataspec, attribute_idx)
        for attribute_idx in oblique.attributes
    ]
    return NumericalSparseObliqueCondition(attributes, list(oblique.weights),
                                           oblique.threshold, na_value)

  raise ValueError(f"Non supported condition type: {core_condition}")