def build_multi_unit_calibration_layers()

in tensorflow_lattice/python/premade_lib.py [0:0]


def build_multi_unit_calibration_layers(calibration_input_layer,
                                        calibration_output_units, model_config,
                                        layer_output_range,
                                        output_single_tensor, dtype):
  """Builds one calibration layer per feature and applies it to its input.

  Features whose config sets `num_buckets` get a
  `categorical_calibration_layer.CategoricalCalibration` layer; every other
  feature gets a `pwl_calibration_layer.PWLCalibration` layer. Each layer is
  named `'{CALIB_LAYER_NAME}_{feature_name}'`.

  Args:
    calibration_input_layer: Dict mapping feature name to the `tf.keras.Input`
      that is fed into that feature's calibration layer.
    calibration_output_units: Dict mapping feature name to the number of
      calibration output units for that feature. Layers are created in this
      dict's iteration order.
    model_config: Model configuration object describing model architecture
      (one of the model configs in `tfl.configs`). Per-feature settings are
      read via `model_config.feature_config_by_name`.
    layer_output_range: A `tfl.premade_lib.LayerOutputRange` enum. Together
      with the model and feature configs it determines the output bounds and
      the kernel initialization range of each layer.
    output_single_tensor: If True, each feature's calibration output is kept
      as a single tensor. If False, multi-unit layers are built with
      `split_outputs=True` and single-unit outputs are wrapped in a list.
    dtype: dtype passed to every calibration layer.

  Returns:
    Dict mapping feature name to calibration output. The value is the layer
    output itself when `output_single_tensor` is True, and a one-element list
    when the feature has exactly one unit. Otherwise it is the output of a
    layer built with `split_outputs=True`.

  Raises:
    ValueError: If a feature's calibration output units is 0.
  """
  outputs_by_feature = {}
  for name, num_units in calibration_output_units.items():
    if num_units == 0:
      raise ValueError(
          'Feature {} is not used. Calibration output units is 0.'.format(
              name))
    config = model_config.feature_config_by_name(name)
    layer_input = calibration_input_layer[name]
    calib_name = '{}_{}'.format(CALIB_LAYER_NAME, name)
    out_min, out_max, init_min, init_max = _output_range(
        layer_output_range, model_config, config)
    # Multi-unit layers split their own outputs unless the caller wants a
    # single tensor per feature.
    split = num_units > 1 and not output_single_tensor

    if not config.num_buckets:
      # Features without buckets get a piecewise-linear calibrator.
      regularizer = _input_calibration_regularizers(model_config, config)
      mono = config.monotonicity
      # canonicalize_monotonicity is evaluated first on purpose, so an invalid
      # monotonicity value is still reported. An unconstrained calibrator is
      # forced to monotonicity 1 when always-monotonic PWL is requested.
      if (utils.canonicalize_monotonicity(mono) == 0 and
          config.pwl_calibration_always_monotonic):
        mono = 1
      keypoints = config.pwl_calibration_input_keypoints
      initializer = pwl_calibration_layer.UniformOutputInitializer(
          output_min=init_min,
          output_max=init_max,
          monotonicity=mono,
          keypoints=keypoints)
      calibrator = pwl_calibration_layer.PWLCalibration(
          units=num_units,
          input_keypoints=keypoints,
          output_min=out_min,
          output_max=out_max,
          clamp_min=config.pwl_calibration_clamp_min,
          clamp_max=config.pwl_calibration_clamp_max,
          missing_input_value=config.default_value,
          impute_missing=config.default_value is not None,
          kernel_initializer=initializer,
          kernel_regularizer=regularizer,
          monotonicity=mono,
          convexity=config.pwl_calibration_convexity,
          split_outputs=split,
          input_keypoints_type=config.pwl_calibration_input_keypoints_type,
          dtype=dtype,
          name=calib_name)
    else:
      # Bucketized features get a categorical calibrator.
      initializer = tf.keras.initializers.RandomUniform(init_min, init_max)
      # Monotonicities are forwarded only when configured as a list.
      categorical_monotonicities = config.monotonicity
      if not isinstance(categorical_monotonicities, list):
        categorical_monotonicities = None
      calibrator = categorical_calibration_layer.CategoricalCalibration(
          num_buckets=config.num_buckets,
          units=num_units,
          output_min=out_min,
          output_max=out_max,
          kernel_initializer=initializer,
          monotonicities=categorical_monotonicities,
          default_input_value=config.default_value,
          split_outputs=split,
          dtype=dtype,
          name=calib_name)
    calibrated = calibrator(layer_input)

    # Single-unit outputs are wrapped in a list when the caller asked for
    # per-unit outputs. Multi-unit outputs were already split by the layer.
    if num_units == 1 and not output_single_tensor:
      calibrated = [calibrated]
    outputs_by_feature[name] = calibrated
  return outputs_by_feature