def default_extractors()

in tensorflow_model_analysis/api/model_eval_lib.py


def default_extractors(  # pylint: disable=invalid-name
    eval_shared_model: Optional[types.MaybeMultipleEvalSharedModels] = None,
    eval_config: Optional[config_pb2.EvalConfig] = None,
    slice_spec: Optional[List[slicer.SingleSliceSpec]] = None,
    materialize: Optional[bool] = None,
    tensor_adapter_config: Optional[tensor_adapter.TensorAdapterConfig] = None,
    custom_predict_extractor: Optional[extractor.Extractor] = None,
    config_version: Optional[int] = None) -> List[extractor.Extractor]:
  """Returns the default extractors for use in ExtractAndEvaluate.

  Args:
    eval_shared_model: Shared model (single-model evaluation) or list of shared
      models (multi-model evaluation). Required unless the predictions are
      provided alongside the features (i.e., model-agnostic evaluations).
    eval_config: Eval config.
    slice_spec: Deprecated (use EvalConfig).
    materialize: True to have extractors create materialized output.
    tensor_adapter_config: Tensor adapter config which specifies how to obtain
      tensors from the Arrow RecordBatch. The model's signature will be invoked
      with those tensors (matched by names). If None, an attempt will be made to
      create an adapter based on the model's input signature; otherwise, the
      model will be invoked with raw examples (assuming a signature of a single
      1-D string tensor).
    custom_predict_extractor: Optional custom predict extractor for non-TF
      models.
    config_version: Optional config version for this evaluation. This should not
      be explicitly set by users. It is only intended to be used in cases where
      the provided eval_config was generated internally, and thus not a reliable
      indicator of user intent.

  Returns:
    A list of extractors for use in ExtractAndEvaluate.

  Raises:
    ValueError: If both slice_spec and eval_config are provided.
    NotImplementedError: If eval_config contains mixed serving and eval models
      or an unsupported combination of model types.
  """
  if materialize is None:
    # TODO(b/172969312): Once analysis table is supported, remove defaulting
    #  to false unless 'analysis' is in disabled_outputs.
    materialize = False
  if slice_spec and eval_config:
    raise ValueError('slice_spec is deprecated, only use eval_config')
  if eval_config is not None:
    eval_config = _update_eval_config_with_defaults(eval_config,
                                                    eval_shared_model)

  if _is_legacy_eval(config_version, eval_shared_model, eval_config):
    # Backwards compatibility for previous add_metrics_callbacks implementation.
    if not eval_config and slice_spec:
      eval_config = config_pb2.EvalConfig(
          slicing_specs=[s.to_proto() for s in slice_spec])
    return [
        custom_predict_extractor or legacy_predict_extractor.PredictExtractor(
            eval_shared_model, materialize=materialize),
        slice_key_extractor.SliceKeyExtractor(
            eval_config=eval_config, materialize=materialize)
    ]
  slicing_extractors = []
  if _has_sql_slices(eval_config):
    slicing_extractors.append(
        sql_slice_key_extractor.SqlSliceKeyExtractor(eval_config))
  slicing_extractors.extend([
      unbatch_extractor.UnbatchExtractor(),
      slice_key_extractor.SliceKeyExtractor(
          eval_config=eval_config, materialize=materialize)
  ])
  if eval_shared_model:
    model_types = _model_types(eval_shared_model)
    eval_shared_models = model_util.verify_and_update_eval_shared_models(
        eval_shared_model)

    if (not model_types.issubset(constants.VALID_TF_MODEL_TYPES) and
        not custom_predict_extractor):
      raise NotImplementedError(
          'either a custom_predict_extractor must be used or model type must '
          'be one of: {}. eval_config={}'.format(
              str(constants.VALID_TF_MODEL_TYPES), eval_config))
    if model_types == set([constants.TF_LITE]):
      # TODO(b/163889779): Convert TFLite extractor to operate on batched
      # extracts. Then we can remove the input extractor.
      return [
          features_extractor.FeaturesExtractor(eval_config=eval_config),
          transformed_features_extractor.TransformedFeaturesExtractor(
              eval_config=eval_config,
              eval_shared_model=eval_shared_model,
              tensor_adapter_config=tensor_adapter_config),
          labels_extractor.LabelsExtractor(eval_config=eval_config),
          example_weights_extractor.ExampleWeightsExtractor(
              eval_config=eval_config),
          (custom_predict_extractor or
           tflite_predict_extractor.TFLitePredictExtractor(
               eval_config=eval_config, eval_shared_model=eval_shared_model))
      ] + slicing_extractors
    elif constants.TF_LITE in model_types:
      raise NotImplementedError(
          'support for mixing tf_lite and non-tf_lite models is not '
          'implemented: eval_config={}'.format(eval_config))

    if model_types == set([constants.TF_JS]):
      return [
          features_extractor.FeaturesExtractor(eval_config=eval_config),
          labels_extractor.LabelsExtractor(eval_config=eval_config),
          example_weights_extractor.ExampleWeightsExtractor(
              eval_config=eval_config),
          (custom_predict_extractor or
           tfjs_predict_extractor.TFJSPredictExtractor(
               eval_config=eval_config, eval_shared_model=eval_shared_model))
      ] + slicing_extractors
    elif constants.TF_JS in model_types:
      raise NotImplementedError(
          'support for mixing tf_js and non-tf_js models is not '
          'implemented: eval_config={}'.format(eval_config))

    elif (eval_config and model_types == set([constants.TF_ESTIMATOR]) and
          all(eval_constants.EVAL_TAG in m.model_loader.tags
              for m in eval_shared_models)):
      return [
          custom_predict_extractor or legacy_predict_extractor.PredictExtractor(
              eval_shared_model,
              materialize=materialize,
              eval_config=eval_config)
      ] + slicing_extractors
    elif (eval_config and constants.TF_ESTIMATOR in model_types and
          any(eval_constants.EVAL_TAG in m.model_loader.tags
              for m in eval_shared_models)):
      raise NotImplementedError(
          'support for mixing eval and non-eval estimator models is not '
          'implemented: eval_config={}'.format(eval_config))
    else:
      extractors = [
          features_extractor.FeaturesExtractor(eval_config=eval_config)
      ]
      if not custom_predict_extractor:
        extractors.append(
            transformed_features_extractor.TransformedFeaturesExtractor(
                eval_config=eval_config,
                eval_shared_model=eval_shared_model,
                tensor_adapter_config=tensor_adapter_config))
      extractors.extend([
          labels_extractor.LabelsExtractor(eval_config=eval_config),
          example_weights_extractor.ExampleWeightsExtractor(
              eval_config=eval_config),
          (custom_predict_extractor or
           predictions_extractor.PredictionsExtractor(
               eval_config=eval_config,
               eval_shared_model=eval_shared_model,
               tensor_adapter_config=tensor_adapter_config)),
      ])
      extractors.extend(slicing_extractors)
      return extractors
  else:
    return [
        features_extractor.FeaturesExtractor(eval_config=eval_config),
        labels_extractor.LabelsExtractor(eval_config=eval_config),
        example_weights_extractor.ExampleWeightsExtractor(
            eval_config=eval_config),
        predictions_extractor.PredictionsExtractor(eval_config=eval_config)
    ] + slicing_extractors
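
A minimal usage sketch: the extractors returned above are typically combined
with default_evaluators and applied via ExtractAndEvaluate inside a Beam
pipeline. The label key, file pattern, and model path below are placeholder
assumptions, as is reading the data with a tfx_bsl TFExampleRecord source and
naming the raw-example column via tfma.ARROW_INPUT_COLUMN.

import apache_beam as beam
import tensorflow_model_analysis as tfma
from tfx_bsl.tfxio import tf_example_record

eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='label')],
    slicing_specs=[tfma.SlicingSpec()])
eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path='/path/to/saved_model', eval_config=eval_config)

extractors = tfma.default_extractors(
    eval_shared_model=eval_shared_model, eval_config=eval_config)
evaluators = tfma.default_evaluators(
    eval_shared_model=eval_shared_model, eval_config=eval_config)

# TFXIO source producing Arrow RecordBatches plus a raw serialized-example
# column, which BatchedInputsToExtracts converts into extracts.
tfxio = tf_example_record.TFExampleRecord(
    file_pattern='/path/to/eval_examples*',
    raw_record_column_name=tfma.ARROW_INPUT_COLUMN)

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | 'ReadExamples' >> tfxio.BeamSource()
      | 'ToExtracts' >> tfma.BatchedInputsToExtracts()
      | 'ExtractAndEvaluate' >> tfma.ExtractAndEvaluate(
          extractors=extractors, evaluators=evaluators))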