in tensorflow_model_remediation/min_diff/keras/utils/structure_utils.py [0:0]
def _flatten_min_diff_structure(struct, run_validation=False):
# pyformat: disable
"""Flattens a MinDiff structure after optionally validating it.
Arguments:
struct: structure to be flattened. Must be a single element (including a
tuple) or an unnested dict.
run_validation: Boolean indicating whether to run validation. If `True`
`validate_min_diff_structure` will be called on `struct`.
Has similar behavior to `tf.nest.flatten` with the exception that tuples will
be considered as single elements instead of structures to be flattened. See
`tf.nest.flatten` documentation for additional details on behavior.
Returns:
A Python list, the flattened version of the input.
Raises:
ValueError: If struct is not a valid MinDiff structure (a single element
including a tuple, or a dict).
"""
# pyformat: enable
if run_validation:
validate_min_diff_structure(struct)
if isinstance(struct, dict):
return [struct[key] for key in sorted(struct.keys())]
return [struct] # Wrap in a list if not nested.