in tfx/benchmarks/tfma_v2_benchmark_base.py [0:0]
def _init_model(self, multi_model, validation):
  """Builds the TFMA eval config and eval shared model(s) for a benchmark.

  The benchmark runner instantiates this class twice: once to discover the
  benchmarks and once to actually run them. Keras freezes if the same model
  is loaded twice, so model loading is kept out of the constructor. This
  method is called before each benchmark instead.

  Args:
    multi_model: If true, configure two models named "candidate" and
      "baseline" (the latter marked `is_baseline=True`). Both are loaded from
      the same `self._dataset.trained_saved_model_path()`. Otherwise,
      configure a single unnamed model keyed by "".
    validation: If true, attach a threshold to the single AUC metric. The
      threshold always has a value threshold. In the multi-model case it
      also has a change threshold.

  Side effects:
    Sets `self._eval_config` and `self._eval_shared_models`.
  """
  # Phase 1: metric specs. The single-model call deliberately passes no
  # `model_names` argument at all.
  auc = tf.keras.metrics.AUC(name="auc", num_thresholds=10000)
  model_names = ["candidate", "baseline"]
  spec_kwargs = {"model_names": model_names} if multi_model else {}
  metric_specs = metric_specs_util.specs_from_metrics([auc], **spec_kwargs)

  # Phase 2: optional validation threshold on the only metric. It applies to
  # all slices.
  if validation:
    # NOTE(review): lower and upper bounds are both 0.5, which leaves
    # essentially no passing range for AUC. Presumably this is intentional
    # for benchmarking the validation path; confirm.
    threshold_kwargs = {
        "value_threshold":
            tfma.GenericValueThreshold(
                lower_bound={"value": 0.5}, upper_bound={"value": 0.5}),
    }
    if multi_model:
      # A change threshold only makes sense when there is a baseline model.
      threshold_kwargs["change_threshold"] = tfma.GenericChangeThreshold(
          absolute={"value": -0.001},
          direction=tfma.MetricDirection.HIGHER_IS_BETTER)
    metric_specs[0].metrics[0].threshold.CopyFrom(
        tfma.MetricThreshold(**threshold_kwargs))

  # Phase 3: eval config.
  if multi_model:
    model_specs = [
        tfma.ModelSpec(name="candidate", label_key="tips"),
        tfma.ModelSpec(name="baseline", label_key="tips", is_baseline=True),
    ]
  else:
    model_specs = [tfma.ModelSpec(label_key="tips")]
  self._eval_config = tfma.EvalConfig(
      model_specs=model_specs, metrics_specs=metric_specs)

  # Phase 4: eval shared models. Each one is built from a fresh
  # `trained_saved_model_path()` call, in "candidate", "baseline" order.
  if not multi_model:
    self._eval_shared_models = {
        "":
            tfma.default_eval_shared_model(
                self._dataset.trained_saved_model_path(),
                eval_config=self._eval_config)
    }
    return
  self._eval_shared_models = {
      name: tfma.default_eval_shared_model(
          self._dataset.trained_saved_model_path(),
          eval_config=self._eval_config,
          model_name=name) for name in model_names
  }