# Excerpt: _init_model() from tfx/benchmarks/tfma_v2_benchmark_base.py

  def _init_model(self, multi_model, validation):
    """Builds `self._eval_config` and `self._eval_shared_models`.

    The benchmark runner will instantiate this class twice - once to determine
    the benchmarks to run, and once to actually run them. However, Keras
    freezes if we try to load the same model twice. As such, we have to pull
    the model loading out of the constructor into a separate method which we
    call before each benchmark.

    Args:
      multi_model: If True, configure a candidate/baseline model pair;
        otherwise configure a single unnamed model.
      validation: If True, attach a validation threshold to the (single)
        metric for all slices.
    """
    # Shared across both branches so the metric definition cannot drift
    # between the single- and multi-model configurations.
    metrics = [tf.keras.metrics.AUC(name="auc", num_thresholds=10000)]
    saved_model_path = self._dataset.trained_saved_model_path()
    value_threshold = tfma.GenericValueThreshold(
        lower_bound={"value": 0.5}, upper_bound={"value": 0.5})
    if multi_model:
      model_names = ["candidate", "baseline"]
      metric_specs = metric_specs_util.specs_from_metrics(
          metrics, model_names=model_names)
      if validation:
        # Only one metric, adding a threshold for all slices. The change
        # threshold only makes sense with a baseline to compare against.
        metric_specs[0].metrics[0].threshold.CopyFrom(
            tfma.MetricThreshold(
                value_threshold=value_threshold,
                change_threshold=tfma.GenericChangeThreshold(
                    absolute={"value": -0.001},
                    direction=tfma.MetricDirection.HIGHER_IS_BETTER)))
      self._eval_config = tfma.EvalConfig(
          model_specs=[
              tfma.ModelSpec(name="candidate", label_key="tips"),
              tfma.ModelSpec(
                  name="baseline", label_key="tips", is_baseline=True)
          ],
          metrics_specs=metric_specs)
      # Both entries load the same saved model; only the model name differs.
      self._eval_shared_models = {
          name: tfma.default_eval_shared_model(
              saved_model_path,
              eval_config=self._eval_config,
              model_name=name) for name in model_names
      }
    else:
      metric_specs = metric_specs_util.specs_from_metrics(metrics)
      if validation:
        # Only one metric, adding a threshold for all slices.
        metric_specs[0].metrics[0].threshold.CopyFrom(
            tfma.MetricThreshold(value_threshold=value_threshold))
      self._eval_config = tfma.EvalConfig(
          model_specs=[tfma.ModelSpec(label_key="tips")],
          metrics_specs=metric_specs)
      self._eval_shared_models = {
          "":
              tfma.default_eval_shared_model(
                  saved_model_path, eval_config=self._eval_config)
      }