def verify_hyperparameters()

in tensorflow_lattice/python/linear_lib.py [0:0]


def verify_hyperparameters(num_input_dims=None,
                           units=None,
                           input_shape=None,
                           monotonicities=None,
                           monotonic_dominances=None,
                           range_dominances=None,
                           input_min=None,
                           input_max=None,
                           weights_shape=None):
  """Verifies that all given hyperparameters are consistent.

  This function does not inspect weights themselves. Only their shape. Use
  `assert_constraints()` to assert actual weights against constraints.

  Unlike the linear layer itself, this function requires monotonicities to be
  specified via list or tuple rather than via a single element because that's
  how monotonicities are stored internally.

  See `tfl.layers.Linear` Layer class level comment for detailed description of
  arguments.

  Args:
    num_input_dims: None or number of input dimensions.
    units: Units hyperparameter of Linear layer.
    input_shape: Shape of layer input.
    monotonicities: List or tuple of same length as number of elements in
      `weights` of {-1, 0, 1} which represent monotonicity constraints per
      dimension. -1 stands for decreasing, 0 for no constraints, 1 for
      increasing.
    monotonic_dominances: List of two-element tuples. First element is the index
      of the dominant feature. Second element is the index of the weak feature.
    range_dominances: List of two-element tuples. First element is the index of
      the dominant feature. Second element is the index of the weak feature.
    input_min: List or tuple of the same length as number of elements in
      'weights' of either None or float which specifies the minimum value to
      clip by.
    input_max: List or tuple of the same length as number of elements in
      'weights' of either None or float which specifies the maximum value to
      clip by.
    weights_shape: None or shape of tensor which represents weights of Linear
      layer.

  Raises:
    ValueError: If something is inconsistent.
    AssertionError: If `input_shape` is given without `units` and
      `num_input_dims`, or if dominance constraints are given while the
      canonicalized `monotonicities` is None.
  """
  # It also raises errors if monotonicities specified incorrectly.
  monotonicities = utils.canonicalize_monotonicities(monotonicities)
  input_min = utils.canonicalize_input_bounds(input_min)
  input_max = utils.canonicalize_input_bounds(input_max)

  if monotonicities is not None and num_input_dims is not None:
    if len(monotonicities) != num_input_dims:
      raise ValueError("Number of elements in 'monotonicities' must be equal to"
                       " num_input_dims. monotoniticites: %s, "
                       "len(monotonicities): %d, num_input_dims: %d" %
                       (monotonicities, len(monotonicities), num_input_dims))

  if weights_shape is not None:
    if len(weights_shape) != 2:
      raise ValueError("Expect weights to be a rank 2 tensor. Weights shape: "
                       "%s" % (weights_shape,))
    if monotonicities is not None and weights_shape[0] != len(monotonicities):
      raise ValueError("Number of elements in 'monotonicities' does not "
                       "correspond to number of weights. Weights shape: %s, "
                       "monotonicities: %s" % (weights_shape, monotonicities))
    if input_min is not None and weights_shape[0] != len(input_min):
      raise ValueError(
          "Number of elements in 'input_min' does not correspond "
          "to number of weights. Weights shape: %s, input_min: %s" %
          (weights_shape, input_min))
    if input_max is not None and weights_shape[0] != len(input_max):
      raise ValueError(
          "Number of elements in 'input_max' does not correspond "
          "to number of weights. Weights shape: %s, input_max: %s" %
          (weights_shape, input_max))

  if input_shape is not None:
    assert units is not None and num_input_dims is not None, (
        "'units' and 'num_input_dims' are required to verify 'input_shape'.")
    if (units > 1 and
        (len(input_shape) != 3 or input_shape[1] != units or
         input_shape[2] != num_input_dims)):
      raise ValueError("'input_shape' must be of rank three and number of "
                       "elements of second and third dimensions must be "
                       "equal to 'units' and 'num_input_dims' respectively. "
                       "'input_shape': " + str(input_shape) + "'units': " +
                       str(units) + "'num_input_dims': " + str(num_input_dims))
    elif (units == 1 and
          (len(input_shape) != 2 or input_shape[1] != num_input_dims)):
      raise ValueError("'input_shape' must be of rank two and number of "
                       "elements of second dimension must be equal to "
                       "'num_input_dims'. 'input_shape': " + str(input_shape) +
                       "'num_input_dims': " + str(num_input_dims))

  # Bounds are compared pairwise; if either list is missing nothing is checked.
  for dim, (lower, upper) in enumerate(zip(input_min or [], input_max or [])):
    if lower is not None and upper is not None and lower > upper:
      raise ValueError("Cannot have 'input_min' greater than 'input_max'."
                       "Dimension: %d, input_min[%d]: %f, input_max[%d]: %f" %
                       (dim, dim, input_min[dim], dim, input_max[dim]))

  if monotonic_dominances is not None:
    assert monotonicities is not None, (
        "'monotonicities' is required to verify 'monotonic_dominances'.")
    num_input_dims = len(monotonicities)
    dim_pairs = set()
    for constraint in monotonic_dominances:
      dominant_dim, weak_dim = _checked_dominance_dims(
          constraint, num_input_dims, "monotonic")
      for dim in (dominant_dim, weak_dim):
        if monotonicities[dim] != 1:
          raise ValueError("Monotonic dominance constraint's dimensions must "
                           "be monotonic. Dimension %d is not monotonic." %
                           (dim))
      # A reversed pair seen earlier would contradict this constraint.
      if (weak_dim, dominant_dim) in dim_pairs:
        raise ValueError("Cannot have two monotonic dominance constraints on "
                         "the same pair of features conflicting. Features: %d, "
                         "%d" % (dominant_dim, weak_dim))
      dim_pairs.add((dominant_dim, weak_dim))

  if range_dominances is not None:
    assert monotonicities is not None, (
        "'monotonicities' is required to verify 'range_dominances'.")
    num_input_dims = len(monotonicities)
    dim_pairs = set()
    for constraint in range_dominances:
      dominant_dim, weak_dim = _checked_dominance_dims(
          constraint, num_input_dims, "range")
      if (monotonicities[dominant_dim] != monotonicities[weak_dim] or
          monotonicities[dominant_dim] == 0):
        raise ValueError("Range dominance constraint's dimensions must have "
                         "the same direction of monotonicity. Dimension %d is "
                         "%d. Dimension %d is %d." %
                         (dominant_dim, monotonicities[dominant_dim], weak_dim,
                          monotonicities[weak_dim]))
      for dim in (dominant_dim, weak_dim):
        # The length guard matters when 'weights_shape' is None: bounds length
        # is then unverified and a short list would otherwise raise IndexError
        # instead of the documented ValueError.
        if (input_min is None or dim >= len(input_min) or
            input_min[dim] is None):
          raise ValueError("Range dominance constraint's dimensions must "
                           "have `input_min` set. Dimension %d is not set." %
                           (dim))
        if (input_max is None or dim >= len(input_max) or
            input_max[dim] is None):
          raise ValueError("Range dominance constraint's dimensions must "
                           "have `input_max` set. Dimension %d is not set." %
                           (dim))
      # A reversed pair seen earlier would contradict this constraint.
      if (weak_dim, dominant_dim) in dim_pairs:
        raise ValueError("Cannot have two range dominance constraints on the "
                         "same pair of features conflicting. Features: %d, %d" %
                         (dominant_dim, weak_dim))
      dim_pairs.add((dominant_dim, weak_dim))

  if range_dominances is not None and monotonic_dominances is not None:
    # A dimension may take part in only one kind of dominance constraint.
    monotonic_dominance_dims = {
        dim for dims in monotonic_dominances for dim in dims
    }
    for dims in range_dominances:
      for dim in dims:
        if dim in monotonic_dominance_dims:
          raise ValueError("Cannot have both monotonic and range dominance "
                           "constraints specified on the same dimension. "
                           "Dimension %d is set by both." % (dim))


def _checked_dominance_dims(constraint, num_input_dims, kind):
  """Validates the structure of one dominance constraint.

  Args:
    constraint: Candidate `(dominant_dim, weak_dim)` pair.
    num_input_dims: Number of input dimensions the indices must fall within.
    kind: Either "monotonic" or "range". Only used to word error messages.

  Returns:
    Tuple `(dominant_dim, weak_dim)` taken from `constraint`.

  Raises:
    ValueError: If `constraint` does not have exactly two elements, if either
      element is not an `int`, or if either element is outside
      `[0, num_input_dims)`.
  """
  if len(constraint) != 2:
    raise ValueError("%s dominance constraints must consist of 2 "
                     "elements. Seeing constraint tuple %s" %
                     (kind.capitalize(), constraint))
  dominant_dim, weak_dim = constraint
  # The type check must precede the range check: ordering comparisons between
  # e.g. `str` and `int` raise TypeError, which would mask this error.
  if not isinstance(dominant_dim, int) or not isinstance(weak_dim, int):
    raise ValueError("%s dominance constraint dimensions must be "
                     "integers. Seeing dominant_dim %s and weak_dim %s" %
                     (kind.capitalize(), dominant_dim, weak_dim))
  if (dominant_dim >= num_input_dims or weak_dim >= num_input_dims or
      dominant_dim < 0 or weak_dim < 0):
    raise ValueError("Dimensions constrained by %s dominance "
                     "constraints are not within the input dimensions. "
                     "'dims': %s, %s, num_dims: %s" %
                     (kind, dominant_dim, weak_dim, num_input_dims))
  return dominant_dim, weak_dim