def verify_hyperparameters()

in tensorflow_lattice/python/lattice_lib.py [0:0]


def verify_hyperparameters(lattice_sizes,
                           units=None,
                           weights_shape=None,
                           input_shape=None,
                           monotonicities=None,
                           unimodalities=None,
                           edgeworth_trusts=None,
                           trapezoid_trusts=None,
                           monotonic_dominances=None,
                           range_dominances=None,
                           joint_monotonicities=None,
                           joint_unimodalities=None,
                           output_min=None,
                           output_max=None,
                           regularization_amount=None,
                           regularization_info="",
                           interpolation="hypercube"):
  """Verifies that all given hyperparameters are consistent.

  This function does not inspect weights themselves. Only their shape. Use
  `assert_constraints()` to assert actual weights against constraints.

  See `tfl.layers.Lattice` class level comment for detailed description of
  arguments.

  Args:
    lattice_sizes: Lattice sizes to check against. Every size must be >= 2.
    units: Units hyperparameter of `Lattice` layer.
    weights_shape: Shape of tensor which represents `Lattice` layer weights.
    input_shape: Shape of layer input. Useful only if `units` is set.
    monotonicities: Monotonicities hyperparameter of `Lattice` layer.
    unimodalities: Unimodalities hyperparameter of `Lattice` layer.
    edgeworth_trusts: Edgeworth_trusts hyperparameter of `Lattice` layer.
    trapezoid_trusts: Trapezoid_trusts hyperparameter of `Lattice` layer.
    monotonic_dominances: Monotonic dominances hyperparameter of `Lattice`
      layer.
    range_dominances: Range dominances hyperparameter of `Lattice` layer.
    joint_monotonicities: Joint monotonicities hyperparameter of `Lattice`
      layer.
    joint_unimodalities: Joint unimodalities hyperparameter of `Lattice` layer.
    output_min: Minimum output of `Lattice` layer.
    output_max: Maximum output of `Lattice` layer.
    regularization_amount: Regularization amount for regularizers.
    regularization_info: String which describes `regularization_amount`.
    interpolation: One of 'simplex' or 'hypercube' interpolation.

  Raises:
    ValueError: If something is inconsistent. This includes trust constraints
      given without a monotonic main feature (or without `monotonicities` at
      all), and constraint dimensions that are not integers or that fall
      outside `[0, len(lattice_sizes))`.
  """
  num_dims = len(lattice_sizes)

  for size in lattice_sizes:
    if size < 2:
      raise ValueError("All lattice sizes must be at least 2. Given: %s" %
                       lattice_sizes)

  # It also raises errors if monotonicities specified incorrectly.
  monotonicities = utils.canonicalize_monotonicities(
      monotonicities, allow_decreasing=False)
  if monotonicities is not None:
    if len(monotonicities) != num_dims:
      raise ValueError("If provided 'monotonicities' should have same number "
                       "of elements as 'lattice_sizes'. 'monotonicities': %s,"
                       "'lattice_sizes: %s" % (monotonicities, lattice_sizes))

  unimodalities = utils.canonicalize_unimodalities(unimodalities)
  if unimodalities is not None:
    if len(unimodalities) != num_dims:
      raise ValueError("If provided 'unimodalities' should have same number "
                       "of elements as 'lattice_sizes'. 'unimodalities': %s, "
                       "'lattice_sizes: %s" % (unimodalities, lattice_sizes))
    for unimodality, dim_size in zip(unimodalities, lattice_sizes):
      if unimodality != 0 and dim_size < 3:
        raise ValueError("Unimodal dimensions must have lattice size at "
                         "least 3. unimodalities: %s, lattice_sizes: %s" %
                         (unimodalities, lattice_sizes))

  if monotonicities is not None and unimodalities is not None:
    for i, (monotonicity,
            unimodality) in enumerate(zip(monotonicities, unimodalities)):
      if monotonicity != 0 and unimodality != 0:
        raise ValueError("Both monotonicity and unimodality can not be set "
                         "simultaniously for same dimension. Dimension: %d, "
                         "'monotonicities': %s, 'unimodalities': %s" %
                         (i, monotonicities, unimodalities))

  # Convert both trust collections to lists so that a tuple of constraints and
  # a list of constraints can be concatenated. Edgeworth constraints come first,
  # so indices >= len(edgeworth_list) in `all_trusts` are trapezoid constraints.
  edgeworth_list = list(edgeworth_trusts or [])
  trapezoid_list = list(trapezoid_trusts or [])
  all_trusts = utils.canonicalize_trust(edgeworth_list + trapezoid_list) or []
  main_dims, cond_dims, trapezoid_cond_dims = set(), set(), set()
  dim_pairs_direction = {}
  for i, constraint in enumerate(all_trusts):
    main_dim, cond_dim, cond_direction = constraint
    # Check the type before the range. Comparing a non-number (e.g. a string)
    # with an int raises TypeError instead of this descriptive ValueError.
    if not isinstance(main_dim, int) or not isinstance(cond_dim, int):
      raise ValueError("Trust constraint dimensions must be integers. Seeing "
                       "main_dim %s and cond_dim %s" % (main_dim, cond_dim))
    if (main_dim >= num_dims or cond_dim >= num_dims or
        main_dim < 0 or cond_dim < 0):
      raise ValueError("Dimensions constrained by trust constraints "
                       "are not within the range of the lattice. "
                       "'trust_dims': %s, %s, num_dims: %s" %
                       (main_dim, cond_dim, num_dims))
    # Without `monotonicities` no feature is monotonic. Raise the same error
    # instead of failing on `None[main_dim]`.
    if monotonicities is None or monotonicities[main_dim] != 1:
      raise ValueError("Trust constraint's main feature must be "
                       "monotonic. Dimension %s is not monotonic." % (main_dim))
    if (main_dim, cond_dim) in dim_pairs_direction and dim_pairs_direction[
        (main_dim, cond_dim)] != cond_direction:
      raise ValueError("Cannot have two trust constraints on the same pair of "
                       "features in opposite directions. Features: %d, %d" %
                       (main_dim, cond_dim))
    # Only apply this check to trapezoid constraints when there are also
    # edgeworth constraints.
    if edgeworth_list and i >= len(edgeworth_list):
      if cond_dim in trapezoid_cond_dims:
        logging.warning(
            "Conditional dimension %d is being used in multiple trapezoid "
            "trust constraints. Because of this and the presence of edgeworth "
            "constraints, there may be slight trust violations of one or more "
            "of these constraints at the end of training. Consider increasing "
            "num_projection_iterations to reduce violation.", cond_dim)
      trapezoid_cond_dims.add(cond_dim)
    main_dims.add(main_dim)
    cond_dims.add(cond_dim)
    dim_pairs_direction[(main_dim, cond_dim)] = cond_direction
  main_and_cond = main_dims.intersection(cond_dims)
  if main_and_cond:
    raise ValueError("A feature cannot be both a main feature and a "
                     "conditional feature in trust constraints. "
                     "Seeing dimension %d in both" % (main_and_cond.pop()))

  if monotonic_dominances is not None:
    _verify_dominances_hyperparameters(monotonic_dominances, "monotonic",
                                       monotonicities, num_dims)
  if range_dominances is not None:
    _verify_dominances_hyperparameters(range_dominances, "range",
                                       monotonicities, num_dims)

  if joint_monotonicities is not None:
    for constraint in joint_monotonicities:
      if len(constraint) != 2:
        raise ValueError("Joint monotonicities constraints must consist of 2 "
                         "elements. Seeing constraint tuple %s" % (constraint,))
      dim1, dim2 = constraint
      # Type check first so non-numeric dims get a ValueError, not TypeError.
      if not isinstance(dim1, int) or not isinstance(dim2, int):
        raise ValueError("Joint monotonicity constraint dimensions must be "
                         "integers. Seeing dimensions %s, %s" % (dim1, dim2))
      if (dim1 >= num_dims or dim2 >= num_dims or dim1 < 0 or dim2 < 0):
        raise ValueError("Dimensions constrained by joint monotonicity "
                         "constraints are not within the range of the lattice. "
                         "'dims': %s, %s, num_dims: %s" %
                         (dim1, dim2, num_dims))

  if joint_unimodalities is not None:
    for single_constraint in joint_unimodalities:
      dimensions, direction = single_constraint
      if (not isinstance(direction, six.string_types) or
          (direction.lower() != "valley" and direction.lower() != "peak")):
        raise ValueError("Joint unimodality tuple must end with string 'valley'"
                         " or 'peak' which represents unimodality direction. "
                         "Given: %s" % (single_constraint,))
      for dim in dimensions:
        # Type check first so non-numeric dims get a ValueError, not TypeError.
        if not isinstance(dim, int):
          raise ValueError("Joint unimodality constraint dimensions must be "
                           "integer. Seeing: %s" % (dim,))
        if dim < 0 or dim >= num_dims:
          raise ValueError("Dimension constrained by joint unimodality is not "
                           "within the range of the lattice. Joint unimodality "
                           "dimension: %s, total number of dimensions: "
                           "%s" % (dim, num_dims))
        if lattice_sizes[dim] < 3:
          raise ValueError("Dimensions constrained for joint unimodality must "
                           "have lattice size at least 3. "
                           "Dim: %s has size: %s" % (dim, lattice_sizes[dim]))
        if monotonicities and monotonicities[dim] != 0:
          raise ValueError("Dimension %d constrained for joint_unimodalities "
                           "can not also by monotonic." % dim)
      dims_set = set(dimensions)
      if len(dims_set) != len(dimensions):
        # Wrap in a 1-tuple: `single_constraint` is itself a 2-tuple and would
        # otherwise be unpacked by `%`, raising TypeError.
        raise ValueError("All dimensions within single joint unimodality "
                         "constraint must be distinct. "
                         "Given: %s" % (single_constraint,))

  if weights_shape is not None:
    if len(weights_shape) != 2:
      # Wrap in a 1-tuple so a tuple-valued shape is not unpacked by `%`.
      raise ValueError("Weights must have shape of rank-2. "
                       "Given: %s" % (weights_shape,))
    expected_num_weights = 1
    for dim_size in lattice_sizes:
      expected_num_weights *= dim_size
    if weights_shape[0] != expected_num_weights:
      raise ValueError("Number of elements in weights does not correspond to "
                       "lattice sizes. Weights shape: %s, lattice sizes: %s, "
                       "Number of elements defined by lattice sizes: %d" %
                       (weights_shape, lattice_sizes, expected_num_weights))

  if input_shape is not None:
    if not isinstance(input_shape, list):
      if input_shape[-1] != num_dims:
        raise ValueError("Last dimension of input shape must have same number "
                         "of elements as 'lattice_sizes'. 'input shape': %s, "
                         "'lattice_sizes': %s" % (input_shape, lattice_sizes))
      shape = input_shape
    else:
      if len(input_shape) != num_dims:
        raise ValueError("If lattice input is provided as list of tensors their"
                         " number must match lattice_sizes. 'input list': %s, "
                         "'lattice_sizes': %s" % (input_shape, lattice_sizes))
      shape = input_shape[0]
    if units is not None:  # FYI: It is inside "if input_shape is not None:"
      if units > 1 and (len(shape) < 3 or shape[-2] != units):
        raise ValueError("If 'units' > 1 then input shape of Lattice layer must"
                         " have rank at least 3 where second from last "
                         "dimension is equal to 'units'. 'units': %s, "
                         "input_shape: %s" % (units, input_shape))

  if output_min is not None and output_max is not None:
    if output_min >= output_max:
      raise ValueError("'output_min' must be not greater than 'output_max'. "
                       "'output_min': %f, 'output_max': %f" %
                       (output_min, output_max))

  if regularization_amount and isinstance(regularization_amount, (list, tuple)):
    if len(regularization_amount) != num_dims:
      raise ValueError(
          "If %s losses are given per dimension their number must "
          "match number of dimensions defined by lattice sizes. "
          "l1: %s, lattice sizes: %s" %
          (regularization_info, regularization_amount, lattice_sizes))

  if interpolation not in ["hypercube", "simplex"]:
    raise ValueError("Lattice interpolation type should be either 'simplex' "
                     "or 'hypercube': %s" % interpolation)