in tensorflow_model_remediation/min_diff/keras/utils/structure_utils.py [0:0]
def validate_min_diff_structure(struct,
                                struct_name="struct",
                                element_type=None):
  # pyformat: disable
  """Validates that `struct` is a valid MinDiff structure.

  Arguments:
    struct: Structure that will be validated.
    struct_name: Name of struct, used only for error messages.
    element_type: Type of elements. If `None`, types are not checked.

  A `struct` is a valid MinDiff structure if it is either a single element
  (including tuples) or is an unnested dictionary (with string keys). If
  `element_type` is set, the function will also validate that all elements are
  of the correct type. An empty dict is accepted.

  Returns:
    `None`. The function only raises on invalid input.

  Raises:
    TypeError: If `struct` is neither a single element (including a tuple) nor a
      dict.
    ValueError: If `struct` is a dict with non-string keys.
    ValueError: If `struct` is a dict with values that are not single elements
      (including tuples).
  """
  # pyformat: enable
  # `_is_min_diff_element` is defined elsewhere in this module; it decides what
  # counts as a single element (and checks `element_type` when it is set).
  if _is_min_diff_element(struct, element_type):
    return  # Valid single MinDiff element.

  err_msg = "`{}` is not a recognized MinDiff structure.".format(struct_name)

  # If struct is not a min_diff_element, then it should be a dict. If not,
  # raise an error.
  if not isinstance(struct, dict):
    accepted_types_msg = "a single unnested element (or tuple)"
    if element_type is not None:
      accepted_types_msg += " of type {}".format(element_type)
    accepted_types_msg += ", or a dict"
    raise TypeError("{} It should have a type of one of: {}. Given: {}".format(
        err_msg, accepted_types_msg, type(struct)))

  # Validate that keys are strings. Iterating a dict yields its keys. An empty
  # dict passes because `all()` of an empty iterable is True.
  if not all(isinstance(key, str) for key in struct):
    raise ValueError(
        "{} If `{}` is a dict, it must contain only string keys, given keys: {}"
        .format(err_msg, struct_name, list(struct.keys())))

  # Validate that values are all single MinDiff elements (i.e. the dict is not
  # nested). The generator short-circuits on the first invalid value.
  if not all(
      _is_min_diff_element(element, element_type)
      for element in struct.values()):
    # The leading space separates this sentence from the one already in
    # `err_msg`, which ends with a period.
    err_msg += " If it is a dict, it must be unnested"
    if element_type is not None:
      err_msg += " and contain only elements of type {}".format(element_type)
    raise ValueError("{}. Given: {}".format(err_msg, struct))