def verify_hyperparameters()

in tensorflow_lattice/python/kronecker_factored_lattice_lib.py [0:0]


def verify_hyperparameters(lattice_sizes=None,
                           units=None,
                           num_terms=None,
                           input_shape=None,
                           monotonicities=None,
                           output_min=None,
                           output_max=None):
  """Verifies that all given hyperparameters are consistent.

  This function does not inspect weights themselves. Only their shape. Use
  `assert_constraints()` to assert actual weights against constraints.

  See `tfl.layers.KroneckerFactoredLattice` class level comment for detailed
  description of arguments.

  Args:
    lattice_sizes: Lattice size to check against.
    units: Units hyperparameter of `KroneckerFactoredLattice` layer.
    num_terms: Number of independently trained submodels hyperparameter of
      `KroneckerFactoredLattice` layer.
    input_shape: Shape of layer input. Useful only if `units` and/or
      `monotonicities` is set.
    monotonicities: Monotonicities hyperparameter of `KroneckerFactoredLattice`
      layer. Useful only if `input_shape` is set.
    output_min: Minimum output of `KroneckerFactoredLattice` layer.
    output_max: Maximum output of `KroneckerFactoredLattice` layer.

  Raises:
    ValueError: If lattice_sizes < 2.
    ValueError: If units < 1.
    ValueError: If num_terms < 1.
    ValueError: If len(monotonicities) does not match number of inputs.
    ValueError: If output_min >= output_max.
  """
  # Use explicit `is not None` checks so the invalid value 0 (which is falsy)
  # is rejected as the docstring promises, instead of silently passing.
  if lattice_sizes is not None and lattice_sizes < 2:
    raise ValueError("Lattice size must be at least 2. Given: %s" %
                     lattice_sizes)

  if units is not None and units < 1:
    raise ValueError("Units must be at least 1. Given: %s" % units)

  if num_terms is not None and num_terms < 1:
    raise ValueError("Number of terms must be at least 1. Given: %s" %
                     num_terms)

  # input_shape: (batch, ..., units, dims)
  if input_shape:
    # It also raises errors if monotonicities is specified incorrectly.
    monotonicities = utils.canonicalize_monotonicities(
        monotonicities, allow_decreasing=False)
    # Extract shape to check units and dims to check monotonicity.
    if isinstance(input_shape, list):
      # A list of per-input tensors: one tensor per dimension.
      dims = len(input_shape)
      # Check monotonicity.
      if monotonicities and len(monotonicities) != dims:
        raise ValueError("If input is provided as list of tensors, their number"
                         " must match monotonicities. 'input_list': %s, "
                         "'monotonicities': %s" % (input_shape, monotonicities))
      shape = input_shape[0]
    else:
      # A single TensorShape: last dimension holds the number of inputs.
      dims = input_shape.as_list()[-1]
      # Check monotonicity.
      if monotonicities and len(monotonicities) != dims:
        raise ValueError("Last dimension of input shape must have same number "
                         "of elements as 'monotonicities'. 'input shape': %s, "
                         "'monotonicities': %s" % (input_shape, monotonicities))
      shape = input_shape
    # For multi-unit layers the second-to-last dimension must carry `units`.
    if units is not None and units > 1 and (len(shape) < 3 or
                                            shape[-2] != units):
      raise ValueError("If 'units' > 1 then input shape of "
                       "KroneckerFactoredLattice layer must have rank at least "
                       "3 where the second from the last dimension is equal to "
                       "'units'. 'units': %s, 'input_shape: %s" %
                       (units, input_shape))

  if output_min is not None and output_max is not None:
    if output_min >= output_max:
      raise ValueError("'output_min' must be strictly less than 'output_max'. "
                       "'output_min': %f, 'output_max': %f" %
                       (output_min, output_max))