def default_eval_shared_model()

in tensorflow_model_analysis/api/model_eval_lib.py [0:0]


def default_eval_shared_model(
    eval_saved_model_path: str,
    add_metrics_callbacks: Optional[List[types.AddMetricsCallbackType]] = None,
    include_default_metrics: Optional[bool] = True,
    example_weight_key: Optional[Union[str, Dict[str, str]]] = None,
    additional_fetches: Optional[List[str]] = None,
    blacklist_feature_fetches: Optional[List[str]] = None,
    tags: Optional[List[str]] = None,
    model_name: str = '',
    eval_config: Optional[config_pb2.EvalConfig] = None,
    custom_model_loader: Optional[types.ModelLoader] = None,
    rubber_stamp: Optional[bool] = False) -> types.EvalSharedModel:
  """Returns default EvalSharedModel.

  Args:
    eval_saved_model_path: Path to EvalSavedModel.
    add_metrics_callbacks: Optional list of callbacks for adding additional
      metrics to the graph (see EvalSharedModel for more information on how to
      configure additional metrics). Metrics for example count and example
      weights will be added automatically. Only used if EvalSavedModel used.
      The caller's list is never modified; a copy is extended instead.
    include_default_metrics: DEPRECATED. Use
      eval_config.options.include_default_metrics.
    example_weight_key: DEPRECATED. Use
      eval_config.model_specs.example_weight_key or
      eval_config.model_specs.example_weight_keys.
    additional_fetches: Optional prefixes of additional tensors stored in
      signature_def.inputs that should be fetched at prediction time. The
      "features" and "labels" tensors are handled automatically and should not
      be included. Only used if EvalSavedModel used.
    blacklist_feature_fetches: Optional list of tensor names in the features
      dictionary which should be excluded from the fetches request. This is
      useful in scenarios where features are large (e.g. images) and can lead to
      excessive memory use if stored. Only used if EvalSavedModel used.
    tags: Optional model tags (e.g. 'serve' for serving or 'eval' for
      EvalSavedModel).
    model_name: Optional name of the model being created (should match
      ModelSpecs.name). The name should only be provided if multiple models are
      being evaluated.
    eval_config: Eval config.
    custom_model_loader: Optional custom model loader for non-TF models.
    rubber_stamp: True when this is a first run without a baseline model even
      though a baseline is configured; in that case the diff thresholds are
      ignored.

  Returns:
    A types.EvalSharedModel describing the model. When eval_config is given,
    the matching ModelSpec supplies the model type, baseline flag and example
    weight key(s), and eval_config.options may override
    include_default_metrics. If no custom_model_loader is given and the model
    type is in constants.VALID_TF_MODEL_TYPES, a types.ModelLoader built from
    model_util.model_construct_fn is attached.

  Raises:
    ValueError: If eval_config is provided but has no ModelSpec for
      model_name.
  """
  if not eval_config:
    is_baseline = False
    model_type = constants.TF_ESTIMATOR
    if tags is None:
      tags = [eval_constants.EVAL_TAG]
  else:
    model_spec = model_util.get_model_spec(eval_config, model_name)
    if not model_spec:
      raise ValueError('ModelSpec for model name {} not found in EvalConfig: '
                       'config={}'.format(model_name, eval_config))
    is_baseline = model_spec.is_baseline
    model_type = model_util.get_model_type(model_spec, eval_saved_model_path,
                                           tags)
    if tags is None:
      # Default to serving unless estimator is used.
      if model_type == constants.TF_ESTIMATOR:
        tags = [eval_constants.EVAL_TAG]
      else:
        tags = [tf.saved_model.SERVING]
    # Settings from the ModelSpec take precedence over the deprecated
    # example_weight_key argument.
    if model_spec.example_weight_key:
      example_weight_key = model_spec.example_weight_key
    elif model_spec.example_weight_keys:
      # example_weight_keys is a protobuf map field; protobuf map containers
      # are not `dict` instances. Copy into a plain dict so the per-output
      # branch below (`isinstance(..., dict)`) is taken, rather than passing
      # the whole map to example_weight() as if it were a single key.
      example_weight_key = dict(model_spec.example_weight_keys)
    if eval_config.options.HasField('include_default_metrics'):
      include_default_metrics = (
          eval_config.options.include_default_metrics.value)

  # Backwards compatibility for legacy add_metrics_callbacks implementation.
  if model_type == constants.TF_ESTIMATOR and eval_constants.EVAL_TAG in tags:
    # PyType doesn't know about the magic exports we do in post_export_metrics.
    # Additionally, the lines seem to get reordered in compilation, so we can't
    # just put the disable-attr on the add_metrics_callbacks lines.
    # pytype: disable=module-attr
    # Work on a copy: appending to the caller's list would leak the default
    # callbacks back to the caller and duplicate them on repeated calls that
    # reuse the same list.
    add_metrics_callbacks = list(add_metrics_callbacks or [])
    if include_default_metrics:
      # Always compute example weight and example count if default metrics are
      # enabled.
      example_count_callback = post_export_metrics.example_count()
      add_metrics_callbacks.append(example_count_callback)
      if example_weight_key:
        if isinstance(example_weight_key, dict):
          # One example_weight metric per output, tagged with the output name.
          for output_name, key in example_weight_key.items():
            example_weight_callback = post_export_metrics.example_weight(
                key, metric_tag=output_name)
            add_metrics_callbacks.append(example_weight_callback)
        else:
          example_weight_callback = post_export_metrics.example_weight(
              example_weight_key)
          add_metrics_callbacks.append(example_weight_callback)
    # pytype: enable=module-attr

  # Only build a default TF loader when the caller didn't supply one and the
  # model type is one this library knows how to construct.
  model_loader = custom_model_loader
  if not model_loader and model_type in constants.VALID_TF_MODEL_TYPES:
    model_loader = types.ModelLoader(
        construct_fn=model_util.model_construct_fn(
            eval_saved_model_path=eval_saved_model_path,
            add_metrics_callbacks=add_metrics_callbacks,
            include_default_metrics=include_default_metrics,
            additional_fetches=additional_fetches,
            blacklist_feature_fetches=blacklist_feature_fetches,
            model_type=model_type,
            tags=tags),
        tags=tags)

  return types.EvalSharedModel(
      model_name=model_name,
      model_type=model_type,
      model_path=eval_saved_model_path,
      add_metrics_callbacks=add_metrics_callbacks,
      include_default_metrics=include_default_metrics,
      example_weight_key=example_weight_key,
      additional_fetches=additional_fetches,
      model_loader=model_loader,
      rubber_stamp=rubber_stamp,
      is_baseline=is_baseline)