in tensorflow_model_analysis/api/model_eval_lib.py [0:0]
def default_eval_shared_model(
    eval_saved_model_path: str,
    add_metrics_callbacks: Optional[List[types.AddMetricsCallbackType]] = None,
    include_default_metrics: Optional[bool] = True,
    example_weight_key: Optional[Union[str, Dict[str, str]]] = None,
    additional_fetches: Optional[List[str]] = None,
    blacklist_feature_fetches: Optional[List[str]] = None,
    tags: Optional[List[str]] = None,
    model_name: str = '',
    eval_config: Optional[config_pb2.EvalConfig] = None,
    custom_model_loader: Optional[types.ModelLoader] = None,
    rubber_stamp: Optional[bool] = False) -> types.EvalSharedModel:
  """Returns default EvalSharedModel.

  Args:
    eval_saved_model_path: Path to EvalSavedModel.
    add_metrics_callbacks: Optional list of callbacks for adding additional
      metrics to the graph (see EvalSharedModel for more information on how to
      configure additional metrics). Metrics for example count and example
      weights will be added automatically. Only used if EvalSavedModel used.
    include_default_metrics: DEPRECATED. Use
      eval_config.options.include_default_metrics.
    example_weight_key: DEPRECATED. Use
      eval_config.model_specs.example_weight_key or
      eval_config.model_specs.example_weight_keys.
    additional_fetches: Optional prefixes of additional tensors stored in
      signature_def.inputs that should be fetched at prediction time. The
      "features" and "labels" tensors are handled automatically and should not
      be included. Only used if EvalSavedModel used.
    blacklist_feature_fetches: Optional list of tensor names in the features
      dictionary which should be excluded from the fetches request. This is
      useful in scenarios where features are large (e.g. images) and can lead
      to excessive memory use if stored. Only used if EvalSavedModel used.
    tags: Optional model tags (e.g. 'serve' for serving or 'eval' for
      EvalSavedModel). If None, a default is chosen based on the model type.
    model_name: Optional name of the model being created (should match
      ModelSpecs.name). The name should only be provided if multiple models are
      being evaluated.
    eval_config: Eval config.
    custom_model_loader: Optional custom model loader for non-TF models.
    rubber_stamp: True when this run is a first run without a baseline model
      while a baseline is configured; the diff thresholds will be ignored.

  Returns:
    EvalSharedModel populated with the resolved model type, tags, metrics
    callbacks and model loader.

  Raises:
    ValueError: If eval_config is provided but contains no ModelSpec matching
      model_name.
  """
  if not eval_config:
    # Legacy path: without an EvalConfig the model is assumed to be an
    # estimator-based EvalSavedModel.
    is_baseline = False
    model_type = constants.TF_ESTIMATOR
    if tags is None:
      tags = [eval_constants.EVAL_TAG]
  else:
    model_spec = model_util.get_model_spec(eval_config, model_name)
    if not model_spec:
      raise ValueError('ModelSpec for model name {} not found in EvalConfig: '
                       'config={}'.format(model_name, eval_config))
    is_baseline = model_spec.is_baseline
    model_type = model_util.get_model_type(model_spec, eval_saved_model_path,
                                           tags)
    if tags is None:
      # Default to serving unless estimator is used.
      if model_type == constants.TF_ESTIMATOR:
        tags = [eval_constants.EVAL_TAG]
      else:
        tags = [tf.saved_model.SERVING]
    # The deprecated example_weight_key argument is overridden by the
    # ModelSpec settings when they are present.
    if model_spec.example_weight_key or model_spec.example_weight_keys:
      example_weight_key = (
          model_spec.example_weight_key or model_spec.example_weight_keys)
    if eval_config.options.HasField('include_default_metrics'):
      include_default_metrics = (
          eval_config.options.include_default_metrics.value)
  # Backwards compatibility for legacy add_metrics_callbacks implementation.
  if model_type == constants.TF_ESTIMATOR and eval_constants.EVAL_TAG in tags:
    # PyType doesn't know about the magic exports we do in post_export_metrics.
    # Additionally, the lines seem to get reordered in compilation, so we can't
    # just put the disable-attr on the add_metrics_callbacks lines.
    # pytype: disable=module-attr
    # Copy the list so the appends below never mutate the caller's argument.
    add_metrics_callbacks = list(add_metrics_callbacks or [])
    if include_default_metrics:
      # Always compute example weight and example count if default metrics are
      # enabled.
      example_count_callback = post_export_metrics.example_count()
      add_metrics_callbacks.append(example_count_callback)
      if example_weight_key:
        if isinstance(example_weight_key, dict):
          for output_name, key in example_weight_key.items():
            example_weight_callback = post_export_metrics.example_weight(
                key, metric_tag=output_name)
            add_metrics_callbacks.append(example_weight_callback)
        else:
          example_weight_callback = post_export_metrics.example_weight(
              example_weight_key)
          add_metrics_callbacks.append(example_weight_callback)
    # pytype: enable=module-attr
  # A custom loader takes precedence; otherwise build the default TF loader
  # for recognized TF model types (non-TF models keep model_loader=None
  # unless custom_model_loader was given).
  model_loader = custom_model_loader
  if not model_loader and model_type in constants.VALID_TF_MODEL_TYPES:
    model_loader = types.ModelLoader(
        construct_fn=model_util.model_construct_fn(
            eval_saved_model_path=eval_saved_model_path,
            add_metrics_callbacks=add_metrics_callbacks,
            include_default_metrics=include_default_metrics,
            additional_fetches=additional_fetches,
            blacklist_feature_fetches=blacklist_feature_fetches,
            model_type=model_type,
            tags=tags),
        tags=tags)
  return types.EvalSharedModel(
      model_name=model_name,
      model_type=model_type,
      model_path=eval_saved_model_path,
      add_metrics_callbacks=add_metrics_callbacks,
      include_default_metrics=include_default_metrics,
      example_weight_key=example_weight_key,
      additional_fetches=additional_fetches,
      model_loader=model_loader,
      rubber_stamp=rubber_stamp,
      is_baseline=is_baseline)