def evaluate_model()

in fairness_indicators/example_model.py [0:0]


def evaluate_model(classifier, validate_tf_file, tfma_eval_result_path,
                   selected_slice, label, feature_map, thresholds=None):
  """Evaluate Model using Tensorflow Model Analysis.

  Exports `classifier` as a TFMA EvalSavedModel and then runs TFMA over the
  validation data with the Fairness Indicators post-export metrics. The
  evaluation covers the overall dataset and a slicing by `selected_slice`.
  Results are written to `tfma_eval_result_path`.

  Args:
    classifier: Trained classifier model (estimator) to be evaluated.
    validate_tf_file: File containing validation TFRecordDataset.
    tfma_eval_result_path: Directory path where eval results will be written.
    selected_slice: Feature for slicing the data.
    label: Name of the groundtruth label feature in `feature_map`.
    feature_map: Dict of feature names to parsing specs, as accepted by
      `tf.io.parse_example`.
    thresholds: Optional iterable of decision thresholds at which the fairness
      metrics are computed. Defaults to [0.1, 0.3, 0.5, 0.7, 0.9].

  Returns:
    The value returned by `tfma.run_model_analysis`.
  """
  # Keep the historical default; copy caller input into a fresh list so the
  # caller's object is never shared with TFMA.
  if thresholds is None:
    thresholds = [0.1, 0.3, 0.5, 0.7, 0.9]
  else:
    thresholds = list(thresholds)

  def eval_input_receiver_fn():
    """Eval Input Receiver function."""
    serialized_tf_example = tf.compat.v1.placeholder(
        dtype=tf.string, shape=[None], name='input_example_placeholder')

    receiver_tensors = {'examples': serialized_tf_example}

    features = tf.io.parse_example(serialized_tf_example, feature_map)
    # Constant per-example weight of 1, shaped like the label tensor.
    # NOTE(review): presumably consumed by TFMA as the example weight — confirm.
    features['weight'] = tf.ones_like(features[label])

    return tfma.export.EvalInputReceiver(
        features=features,
        receiver_tensors=receiver_tensors,
        labels=features[label])

  # The EvalSavedModel is exported under a fixed directory in the system temp
  # location, so repeated calls reuse the same base directory.
  tfma_export_dir = tfma.export.export_eval_savedmodel(
      estimator=classifier,
      export_dir_base=os.path.join(tempfile.gettempdir(), 'tfma_eval_model'),
      eval_input_receiver_fn=eval_input_receiver_fn)

  # Define slices that you want the evaluation to run on.
  slice_spec = [
      tfma.slicer.SingleSliceSpec(),  # Overall slice
      tfma.slicer.SingleSliceSpec(columns=[selected_slice]),
  ]

  # Add the fairness metrics.
  # pytype: disable=module-attr
  add_metrics_callbacks = [
      tfma.post_export_metrics.fairness_indicators(
          thresholds=thresholds, labels_key=label)
  ]
  # pytype: enable=module-attr

  eval_shared_model = tfma.default_eval_shared_model(
      eval_saved_model_path=tfma_export_dir,
      add_metrics_callbacks=add_metrics_callbacks)

  # Run the fairness evaluation and hand the result back to the caller.
  return tfma.run_model_analysis(
      eval_shared_model=eval_shared_model,
      data_location=validate_tf_file,
      output_path=tfma_eval_result_path,
      slice_spec=slice_spec)