def verify_hyperparameters()

in tensorflow_lattice/python/pwl_calibration_lib.py [0:0]


def verify_hyperparameters(input_keypoints=None,
                           output_min=None,
                           output_max=None,
                           monotonicity=None,
                           convexity=None,
                           is_cyclic=False,
                           lengths=None,
                           weights_shape=None,
                           input_keypoints_type=None):
  """Verifies that all given hyperparameters are consistent.

  See PWLCalibration class level comment for detailed description of arguments.
  Every argument is optional; a check is skipped when the argument(s) it needs
  are left as `None`.

  Args:
    input_keypoints: `input_keypoints` of PWLCalibration layer. If it is a
      tensor it must be rank 1 with at least 2 elements (the size check is
      skipped when the static size is unknown; tensor values are not checked).
      Otherwise it must be an indexable sequence of at least 2 strictly
      increasing values.
    output_min: Smallest output of PWLCalibration layer.
    output_max: Largest output of PWLCalibration layer. Must not be smaller
      than `output_min` when both are given.
    monotonicity: `monotonicity` hyperparameter of PWLCalibration layer.
    convexity: `convexity` hyperparameter of PWLCalibration layer.
    is_cyclic: `is_cyclic` hyperparameter of PWLCalibration layer. Cannot be
      combined with a non-zero `monotonicity` or `convexity`.
    lengths: Lengths of pieces of piecewise linear function. When given
      together with `weights_shape`, must have exactly `weights_shape[0] - 1`
      elements (skipped for a tensor whose static size is unknown).
    weights_shape: Shape of weights of PWLCalibration layer. Must be
      `[k, units]` with `k > 1`.
    input_keypoints_type: The type of input keypoints of a PWLCalibration
      layer. Must be 'fixed' or 'learned_interior' if given.

  Raises:
    ValueError: If something is inconsistent.
  """
  if input_keypoints is not None:
    if tf.is_tensor(input_keypoints):
      keypoints_shape = input_keypoints.shape
      is_rank_one = len(keypoints_shape) == 1
      # `dimension_value` yields an int, or None when the static size is
      # unknown (e.g. a traced tensor with a `None` dimension). Comparing None
      # with 2 would raise TypeError, so the size check is skipped instead.
      num_keypoints = (
          tf.compat.dimension_value(keypoints_shape[0]) if is_rank_one else None)
      if not is_rank_one or (num_keypoints is not None and num_keypoints < 2):
        raise ValueError("Input keypoints must be rank-1 tensor of size at "
                         "least 2. It is: " + str(input_keypoints))
    else:
      if len(input_keypoints) < 2:
        raise ValueError("At least 2 input keypoints must be provided. "
                         "Given: " + str(input_keypoints))
      # Compare every adjacent pair of keypoints.
      if not all(left < right for left, right in zip(input_keypoints,
                                                     input_keypoints[1:])):
        raise ValueError("Keypoints must be strictly increasing. They are: " +
                         str(input_keypoints))

  if output_min is not None and output_max is not None:
    if output_max < output_min:
      raise ValueError("If specified output_max must be greater than "
                       "output_min. "
                       "They are: ({}, {})".format(output_min, output_max))

  # It also raises errors if monotonicities specified incorrectly.
  monotonicity = utils.canonicalize_monotonicity(monotonicity)
  convexity = utils.canonicalize_convexity(convexity)

  if is_cyclic and (monotonicity or convexity):
    raise ValueError("'is_cyclic' can not be specified together with "
                     "'monotonicity'({}) or 'convexity'({}).".format(
                         monotonicity, convexity))

  if weights_shape is not None:
    if len(weights_shape) != 2 or weights_shape[0] < 2:
      raise ValueError("PWLCalibrator weights must have shape: [k, units] where"
                       " k > 1. It is: " + str(weights_shape))

  if lengths is not None and weights_shape is not None:
    if tf.is_tensor(lengths):
      # None when the static size is unknown; the check is skipped in that
      # case rather than failing with TypeError on `None + 1`.
      num_lengths = tf.compat.dimension_value(lengths.shape[0])
    else:
      num_lengths = len(lengths)
    if num_lengths is not None and num_lengths + 1 != weights_shape[0]:
      raise ValueError("Number of lengths must be equal to number of weights "
                       "minus one. Lengths: %s, weights_shape: %s" %
                       (lengths, weights_shape))

  if (input_keypoints_type is not None and
      input_keypoints_type not in ("fixed", "learned_interior")):
    # Wrap in a 1-tuple so that a tuple-valued argument is formatted as-is
    # instead of being expanded into multiple %-arguments (TypeError).
    raise ValueError(
        "input_keypoints_type must be one of 'fixed' or 'learned_interior': %s"
        % (input_keypoints_type,))