def _verify_ensemble_config()

in tensorflow_lattice/python/premade_lib.py [0:0]


def _verify_ensemble_config(model_config):
  """Verifies that an ensemble model and feature configs are properly specified.

  Two forms of `model_config.lattices` are accepted:

  * The string 'rtl_layer'. In that case `model_config.num_lattices` must be
    set to at least 2 and every feature config must be compatible with the
    RTL layer (shared lattice size, no unimodality/trust/dominance
    constraints, and only input-calibration regularizers).
  * A list of at least two lattices, each a non-string iterable of feature
    names, e.g. `[['a', 'b'], ['b', 'c']]`.

  Args:
    model_config: Model configuration object describing model architecture.
      Should be one of the model configs in `tfl.configs`.

  Raises:
    ValueError: If `model_config.lattices` is set to 'rtl_layer' and
      `model_config.num_lattices` is not specified.
    ValueError: If `model_config.num_lattices < 2` (for 'rtl_layer') or fewer
      than two lattices are listed (for an explicit list of lattices).
    ValueError: If `model_config.lattices` is set to 'rtl_layer' and
      `lattice_size` is not the same for all features.
    ValueError: If `model_config.lattices` is set to 'rtl_layer' and
      there are features with unimodality constraints.
    ValueError: If `model_config.lattices` is set to 'rtl_layer' and
      there are features with trust constraints.
    ValueError: If `model_config.lattices` is set to 'rtl_layer' and
      there are features with dominance constraints.
    ValueError: If `model_config.lattices` is set to 'rtl_layer' and
      there are per-feature lattice regularizers.
    ValueError: If any entry of `model_config.lattices` is a string, is not
      iterable, or contains non-string values.
    ValueError: If `model_config.lattices` is not set to 'rtl_layer' or a fully
      specified list of lists of feature names.
  """
  # Shared by both branches: an ensemble with a single lattice is rejected.
  too_few_lattices_msg = (
      'CalibratedLatticeEnsemble must have >= 2 lattices. For single '
      'lattice models, use CalibratedLattice instead.')

  if model_config.lattices == 'rtl_layer':
    feature_configs = model_config.feature_configs
    # RTL must have num_lattices specified and >= 2.
    if model_config.num_lattices is None:
      raise ValueError('model_config.num_lattices must be specified when '
                       'model_config.lattices is set to \'rtl_layer\'.')
    if model_config.num_lattices < 2:
      raise ValueError(too_few_lattices_msg)
    # Check that all lattices sizes for all features are the same. The first
    # element is indexed lazily inside the generator so an empty feature list
    # does not raise IndexError.
    if any(feature_config.lattice_size != feature_configs[0].lattice_size
           for feature_config in feature_configs):
      raise ValueError('RTL Layer must have the same lattice size for all '
                       'features.')
    # Check that there are only monotonicity and bound constraints.
    # Both 'none' and 0 mean "no unimodality constraint".
    if any(feature_config.unimodality not in ('none', 0)
           for feature_config in feature_configs):
      raise ValueError(
          'RTL Layer does not currently support unimodality constraints.')
    if any(feature_config.reflects_trust_in is not None
           for feature_config in feature_configs):
      raise ValueError(
          'RTL Layer does not currently support trust constraints.')
    if any(feature_config.dominates is not None
           for feature_config in feature_configs):
      raise ValueError(
          'RTL Layer does not currently support dominance constraints.')
    # Check that there are no per-feature lattice regularizers; only
    # regularizers whose name carries the input-calibration prefix are allowed.
    for feature_config in feature_configs:
      for regularizer_config in feature_config.regularizer_configs or []:
        if not regularizer_config.name.startswith(
            _INPUT_CALIB_REGULARIZER_PREFIX):
          raise ValueError(
              'RTL Layer does not currently support per-feature lattice '
              'regularizers.')
  elif isinstance(model_config.lattices, list):
    # Make sure there are more than one lattice. If not, tell user to use
    # CalibratedLattice instead.
    if len(model_config.lattices) < 2:
      raise ValueError(too_few_lattices_msg)
    for lattice in model_config.lattices:
      # A bare string is iterable and yields string characters, so without the
      # explicit isinstance check it would be accepted and silently treated as
      # a lattice of single-character feature names.
      if (isinstance(lattice, str) or not np.iterable(lattice) or
          any(not isinstance(x, str) for x in lattice)):
        raise ValueError(
            'Lattices are not fully specified for ensemble config.')
  else:
    raise ValueError(
        'Lattices are not fully specified for ensemble config. Lattices must '
        'be set to \'rtl_layer\' or be fully specified as a list of lists of '
        'feature names.')