def validate_min_diff_structure()

in tensorflow_model_remediation/min_diff/keras/utils/structure_utils.py [0:0]


def validate_min_diff_structure(struct,
                                struct_name="struct",
                                element_type=None):
  # pyformat: disable
  """Validates that `struct` is a valid MinDiff structure.

  A `struct` is a valid MinDiff structure if it is either a single element
  (including tuples) or is an unnested dictionary (with string keys). If
  `element_type` is set, the function will also validate that all elements are
  of the correct type. Whether something counts as a single element is decided
  by `_is_min_diff_element`.

  Note that a dict with no entries passes validation since there are no keys or
  values that could violate the constraints.

  Arguments:
    struct: Structure that will be validated.
    struct_name: Name of struct, used only for error messages.
    element_type: Type of elements. If `None`, types are not checked.

  Returns:
    `None`. The function returns silently if `struct` is valid and raises
    otherwise.

  Raises:
    TypeError: If `struct` is neither a single element (including a tuple) nor a
      dict.
    ValueError: If `struct` is a dict with non-string keys.
    ValueError: If `struct` is a dict with values that are not single elements
      (including tuples), or that are not of type `element_type` when it is set.
  """
  # pyformat: enable
  if _is_min_diff_element(struct, element_type):
    return  # Valid single MinDiff element.

  # Common prefix shared by every error raised below.
  err_msg = "`{}` is not a recognized MinDiff structure.".format(struct_name)

  # If struct is not a min_diff_element, then it should be a dict. If not, raise
  # an error.
  if not isinstance(struct, dict):
    accepted_types_msg = "a single unnested element (or tuple)"
    if element_type is not None:
      accepted_types_msg += " of type {}".format(element_type)
    accepted_types_msg += ", or a dict"
    raise TypeError("{} It should have a type of one of: {}. Given: {}".format(
        err_msg, accepted_types_msg, type(struct)))

  # Validate that keys are strings if struct is a dict.
  if not all(isinstance(key, str) for key in struct):
    raise ValueError(
        "{} If `{}` is a dict, it must contain only string keys, given keys: {}"
        .format(err_msg, struct_name, list(struct)))

  # Validate that values are all single MinDiff elements.
  if not all(
      _is_min_diff_element(element, element_type)
      for element in struct.values()):
    # The leading space separates this sentence from the period that ends the
    # common prefix; without it the two sentences run together.
    err_msg += " If it is a dict, it must be unnested"
    if element_type is not None:
      err_msg += " and contain only elements of type {}".format(element_type)
    raise ValueError("{}. Given: {}".format(err_msg, struct))