def _flatten_min_diff_structure()

in tensorflow_model_remediation/min_diff/keras/utils/structure_utils.py [0:0]


def _flatten_min_diff_structure(struct, run_validation=False):
  # pyformat: disable
  """Flattens a MinDiff structure after optionally validating it.

  Arguments:
    struct: structure to be flattened. Must be a single element (including a
      tuple) or an unnested dict.
    run_validation: Boolean indicating whether to run validation. If `True`
      `validate_min_diff_structure` will be called on `struct`.

  Has similar behavior to `tf.nest.flatten` with the exception that tuples will
  be considered as single elements instead of structures to be flattened. See
  `tf.nest.flatten` documentation for additional details on behavior.

  Returns:
    A Python list, the flattened version of the input.

  Raises:
    ValueError: If struct is not a valid MinDiff structure (a single element
      including a tuple, or a dict).
  """
  # pyformat: enable

  if run_validation:
    validate_min_diff_structure(struct)

  if isinstance(struct, dict):
    return [struct[key] for key in sorted(struct.keys())]

  return [struct]  # Wrap in a list if not nested.