in tensorflow_lattice/python/rtl_lib.py [0:0]
def verify_hyperparameters(lattice_size,
input_shape=None,
output_min=None,
output_max=None,
interpolation="hypercube",
parameterization="all_vertices",
kernel_initializer=None,
kernel_regularizer=None):
"""Verifies that all given hyperparameters are consistent.
See `tfl.layers.RTL` class level comment for detailed description of
arguments.
Args:
lattice_size: Lattice size to check againts.
input_shape: Shape of layer input.
output_min: Minimum output of `RTL` layer.
output_max: Maximum output of `RTL` layer.
interpolation: One of 'simplex' or 'hypercube' interpolation.
parameterization: One of 'all_vertices' or 'kronecker_factored'
parameterizations.
kernel_initializer: Initizlier to check against.
kernel_regularizer: Regularizers to check against.
Raises:
ValueError: If lattice_size < 2.
KeyError: If input_shape is a dict with incorrect keys.
ValueError: If output_min >= output_max.
ValueError: If interpolation is not one of 'simplex' or 'hypercube'.
ValueError: If parameterization is 'kronecker_factored' and
kernel_initializer is 'linear_initializer'.
ValueError: If parameterization is 'kronecker_factored' and
kernel_regularizer is not None.
ValueError: If kernel_regularizer contains a tuple with len != 3.
ValueError: If kernel_regularizer contains a tuple with non-float l1 value.
ValueError: If kernel_regularizer contains a tuple with non-flaot l2 value.
"""
if lattice_size < 2:
raise ValueError(
"Lattice size must be at least 2. Given: {}".format(lattice_size))
if input_shape:
if isinstance(input_shape, dict):
for key in input_shape:
if key not in ["unconstrained", "increasing"]:
raise KeyError("Input shape keys should be either 'unconstrained' "
"or 'increasing', but seeing: {}".format(key))
if output_min is not None and output_max is not None:
if output_min >= output_max:
raise ValueError("'output_min' must be not greater than 'output_max'. "
"'output_min': %f, 'output_max': %f" %
(output_min, output_max))
if interpolation not in ["hypercube", "simplex"]:
raise ValueError("RTL interpolation type should be either 'simplex' "
"or 'hypercube': %s" % interpolation)
if (parameterization == "kronecker_factored" and
kernel_initializer == "linear_initializer"):
raise ValueError("'kronecker_factored' parameterization does not currently "
"support linear iniitalization. 'parameterization': %s, "
"'kernel_initializer': %s" %
(parameterization, kernel_initializer))
if (parameterization == "kronecker_factored" and
kernel_regularizer is not None):
raise ValueError("'kronecker_factored' parameterization does not currently "
"support regularization. 'parameterization': %s, "
"'kernel_regularizer': %s" %
(parameterization, kernel_regularizer))
if kernel_regularizer:
if isinstance(kernel_regularizer, list):
regularizers = kernel_regularizer
if isinstance(kernel_regularizer[0], six.string_types):
regularizers = [kernel_regularizer]
for regularizer in regularizers:
if len(regularizer) != 3:
raise ValueError("Regularizer tuples/lists must have three elements "
"(type, l1, and l2). Given: {}".format(regularizer))
_, l1, l2 = regularizer
if not isinstance(l1, float):
raise ValueError(
"Regularizer l1 must be a single float. Given: {}".format(
type(l1)))
if not isinstance(l2, float):
raise ValueError(
"Regularizer l2 must be a single float. Given: {}".format(
type(l2)))