def build_lattice_layer()

in tensorflow_lattice/python/premade_lib.py [0:0]


def build_lattice_layer(lattice_input, feature_configs, model_config,
                        layer_output_range, submodel_index, is_inside_ensemble,
                        dtype):
  """Creates a lattice layer for one submodel and applies it to its input.

  Depending on `model_config.parameterization`, builds either a
  `tfl.layers.Lattice` or a `tfl.layers.KroneckerFactoredLattice` layer and
  immediately calls it on `lattice_input`.

  Args:
    lattice_input: Input to the lattice layer.
    feature_configs: A list of `tfl.configs.FeatureConfig` instances that
      specify configurations for each feature.
    model_config: Model configuration object describing model architecture.
      Should be one of the model configs in `tfl.configs`.
    layer_output_range: A `tfl.premade_lib.LayerOutputRange` enum.
    submodel_index: Corresponding index into submodels. Used as the suffix of
      the created layer's name.
    is_inside_ensemble: If this layer is inside an ensemble. Only controls
      whether a warning is logged when a trapezoid trust constraint's "main"
      feature is not part of this lattice.
    dtype: dtype passed through to the created layer.

  Returns:
    The result of calling the created layer on `lattice_input`. The layer is a
    `tfl.layers.Lattice` if `model_config.parameterization` is
    `'all_vertices'`, or a `tfl.layers.KroneckerFactoredLattice` if it is
    `'kronecker_factored'`.

  Raises:
    ValueError: If a trust constraint whose main feature is present in this
      lattice has a `trust_type` other than `'edgeworth'` or `'trapezoid'`
      (checked first), or if `model_config.parameterization` is not one of
      `'all_vertices'` or `'kronecker_factored'`.
  """
  (output_min, output_max, output_init_min,
   output_init_max) = _output_range(layer_output_range, model_config)

  # Position of each feature within this lattice. `setdefault` keeps the first
  # occurrence of a name, matching `list.index` semantics, while giving O(1)
  # lookups for every trust constraint below.
  feature_indices = {}
  for feature_idx, feature_config in enumerate(feature_configs):
    feature_indices.setdefault(feature_config.name, feature_idx)
  lattice_sizes = [
      feature_config.lattice_size for feature_config in feature_configs
  ]
  lattice_monotonicities = _monotonicities_from_feature_configs(feature_configs)
  lattice_unimodalities = [
      feature_config.unimodality for feature_config in feature_configs
  ]
  # An empty/falsy regularizer spec is normalized to None for the layer.
  lattice_regularizers = _lattice_regularizers(model_config,
                                               feature_configs) or None

  # Construct trust constraints within this lattice. Each constraint is a
  # (main_feature_index, conditional_feature_index, direction) tuple; only
  # constraints whose main feature is also in this lattice can be expressed.
  edgeworth_trusts = []
  trapezoid_trusts = []
  for conditional_idx, conditional_feature_config in enumerate(feature_configs):
    for trust_config in conditional_feature_config.reflects_trust_in or []:
      if trust_config.feature_name in feature_indices:
        main_idx = feature_indices[trust_config.feature_name]
        if trust_config.trust_type == 'edgeworth':
          edgeworth_trusts.append(
              (main_idx, conditional_idx, trust_config.direction))
        elif trust_config.trust_type == 'trapezoid':
          trapezoid_trusts.append(
              (main_idx, conditional_idx, trust_config.direction))
        else:
          raise ValueError('Unrecognized trust type: {}'.format(
              trust_config.trust_type))
      elif is_inside_ensemble and trust_config.trust_type == 'trapezoid':
        logging.warning(
            'A "main" feature (%s) for a trapezoid trust constraint is not '
            'present in a lattice that includes the "conditional" feature '
            '(%s). In an ensemble model, this can result in constraint '
            'violations. Consider manually setting the ensemble structure if '
            'this constraint needs to be satisfied.', trust_config.feature_name,
            conditional_feature_config.name)

  monotonic_dominances = _dominance_constraints_from_feature_configs(
      feature_configs)

  if model_config.parameterization == 'all_vertices':
    layer_name = '{}_{}'.format(LATTICE_LAYER_NAME, submodel_index)
    kernel_initializer = lattice_layer.LinearInitializer(
        lattice_sizes=lattice_sizes,
        monotonicities=lattice_monotonicities,
        unimodalities=lattice_unimodalities,
        output_min=output_init_min,
        output_max=output_init_max)
    return lattice_layer.Lattice(
        lattice_sizes=lattice_sizes,
        monotonicities=lattice_monotonicities,
        unimodalities=lattice_unimodalities,
        edgeworth_trusts=edgeworth_trusts,
        trapezoid_trusts=trapezoid_trusts,
        monotonic_dominances=monotonic_dominances,
        output_min=output_min,
        output_max=output_max,
        clip_inputs=False,
        interpolation=model_config.interpolation,
        kernel_regularizer=lattice_regularizers,
        kernel_initializer=kernel_initializer,
        dtype=dtype,
        name=layer_name)(
            lattice_input)
  elif model_config.parameterization == 'kronecker_factored':
    layer_name = '{}_{}'.format(KFL_LAYER_NAME, submodel_index)
    kernel_initializer = kfll.KFLRandomMonotonicInitializer(
        monotonicities=lattice_monotonicities,
        init_min=output_init_min,
        init_max=output_init_max,
        seed=model_config.random_seed)
    scale_initializer = kfll.ScaleInitializer(
        output_min=output_min, output_max=output_max)
    # NOTE(review): only the first feature's lattice size is passed, so this
    # presumably relies on all features sharing one lattice size (verify the
    # upstream config validation). The unimodalities, trust constraints,
    # dominances and regularizers computed above are not passed to this layer.
    return kfll.KroneckerFactoredLattice(
        lattice_sizes=lattice_sizes[0],
        num_terms=model_config.num_terms,
        monotonicities=lattice_monotonicities,
        output_min=output_min,
        output_max=output_max,
        clip_inputs=False,
        kernel_initializer=kernel_initializer,
        scale_initializer=scale_initializer,
        dtype=dtype,
        name=layer_name)(
            lattice_input)
  else:
    raise ValueError('Unknown type of parameterization: {}'.format(
        model_config.parameterization))