in tensorflow_lattice/python/kronecker_factored_lattice_lib.py [0:0]
def verify_hyperparameters(lattice_sizes=None,
                           units=None,
                           num_terms=None,
                           input_shape=None,
                           monotonicities=None,
                           output_min=None,
                           output_max=None):
  """Verifies that all given hyperparameters are consistent.

  This function does not inspect weights themselves. Only their shape. Use
  `assert_constraints()` to assert actual weights against constraints.

  See `tfl.layers.KroneckerFactoredLattice` class level comment for detailed
  description of arguments.

  Every argument is optional; an argument left as `None` is not checked.

  Args:
    lattice_sizes: Lattice size to check against.
    units: Units hyperparameter of `KroneckerFactoredLattice` layer.
    num_terms: Number of independently trained submodels hyperparameter of
      `KroneckerFactoredLattice` layer.
    input_shape: Shape of layer input. Either a `list` of per-input shapes
      (one entry per input tensor) or a single shape object exposing
      `as_list()` (presumably a `tf.TensorShape`). Useful only if `units`
      and/or `monotonicities` is set.
    monotonicities: Monotonicities hyperparameter of `KroneckerFactoredLattice`
      layer. Useful only if `input_shape` is set.
    output_min: Minimum output of `KroneckerFactoredLattice` layer.
    output_max: Maximum output of `KroneckerFactoredLattice` layer.

  Raises:
    ValueError: If lattice_sizes < 2.
    ValueError: If units < 1.
    ValueError: If num_terms < 1.
    ValueError: If len(monotonicities) does not match number of inputs.
    ValueError: If units > 1 and the input shape does not have rank at least 3
      with its second-to-last dimension equal to units.
    ValueError: If output_min >= output_max.
  """
  # Compare against None explicitly: a truthiness test would let an invalid
  # value of 0 skip the lower-bound checks below.
  if lattice_sizes is not None and lattice_sizes < 2:
    raise ValueError("Lattice size must be at least 2. Given: %s" %
                     lattice_sizes)
  if units is not None and units < 1:
    raise ValueError("Units must be at least 1. Given: %s" % units)
  if num_terms is not None and num_terms < 1:
    raise ValueError("Number of terms must be at least 1. Given: %s" %
                     num_terms)

  # input_shape: (batch, ..., units, dims)
  # A falsy input_shape (e.g. an empty list) skips all shape-based checks.
  if input_shape:
    # It also raises errors if monotonicities is specified incorrectly.
    monotonicities = utils.canonicalize_monotonicities(
        monotonicities, allow_decreasing=False)
    # Extract the shape used for the `units` check and the number of inputs
    # (`dims`) used for the monotonicity check.
    if isinstance(input_shape, list):
      # One shape per input tensor, so the number of inputs is the list length.
      dims = len(input_shape)
      if monotonicities and len(monotonicities) != dims:
        raise ValueError("If input is provided as list of tensors, their number"
                         " must match monotonicities. 'input_list': %s, "
                         "'monotonicities': %s" % (input_shape, monotonicities))
      # Only the first input's shape is checked against `units` below.
      shape = input_shape[0]
    else:
      # Single tensor: per the layout above, inputs occupy the last dimension.
      dims = input_shape.as_list()[-1]
      if monotonicities and len(monotonicities) != dims:
        raise ValueError("Last dimension of input shape must have same number "
                         "of elements as 'monotonicities'. 'input shape': %s, "
                         "'monotonicities': %s" % (input_shape, monotonicities))
      shape = input_shape
    # With units > 1 the second-to-last axis must carry the units.
    if (units is not None and units > 1 and
        (len(shape) < 3 or shape[-2] != units)):
      raise ValueError("If 'units' > 1 then input shape of "
                       "KroneckerFactoredLattice layer must have rank at least "
                       "3 where the second from the last dimension is equal to "
                       "'units'. 'units': %s, 'input_shape': %s" %
                       (units, input_shape))

  if output_min is not None and output_max is not None:
    if output_min >= output_max:
      raise ValueError("'output_min' must be strictly less than 'output_max'. "
                       "'output_min': %f, 'output_max': %f" %
                       (output_min, output_max))