in tensorflow_model_analysis/api/model_eval_lib.py [0:0]
def default_extractors(  # pylint: disable=invalid-name
    eval_shared_model: Optional[types.MaybeMultipleEvalSharedModels] = None,
    eval_config: Optional[config_pb2.EvalConfig] = None,
    slice_spec: Optional[List[slicer.SingleSliceSpec]] = None,
    materialize: Optional[bool] = None,
    tensor_adapter_config: Optional[tensor_adapter.TensorAdapterConfig] = None,
    custom_predict_extractor: Optional[extractor.Extractor] = None,
    config_version: Optional[int] = None) -> List[extractor.Extractor]:
  """Returns the default extractors for use in ExtractAndEvaluate.

  Args:
    eval_shared_model: Shared model (single-model evaluation) or list of shared
      models (multi-model evaluation). Required unless the predictions are
      provided alongside of the features (i.e. model-agnostic evaluations).
    eval_config: Eval config.
    slice_spec: Deprecated (use EvalConfig).
    materialize: True to have extractors create materialized output. Defaults
      to False when None.
    tensor_adapter_config: Tensor adapter config which specifies how to obtain
      tensors from the Arrow RecordBatch. The model's signature will be invoked
      with those tensors (matched by names). If None, an attempt will be made to
      create an adapter based on the model's input signature otherwise the model
      will be invoked with raw examples (assuming a signature of a single 1-D
      string tensor).
    custom_predict_extractor: Optional custom predict extractor for non-TF
      models. When provided it replaces the built-in predict extractor on every
      path that uses `eval_shared_model`. It is not used when
      `eval_shared_model` is None.
    config_version: Optional config version for this evaluation. This should not
      be explicitly set by users. It is only intended to be used in cases where
      the provided eval_config was generated internally, and thus not a reliable
      indicator of user intent.

  Returns:
    The ordered list of extractors to run. For non-legacy evaluations, the
    feature/label/weight/prediction extractors come first, followed by the
    slicing extractors.

  Raises:
    ValueError: If both slice_spec and eval_config are provided.
    NotImplementedError: If a model type outside
      constants.VALID_TF_MODEL_TYPES is used without a custom_predict_extractor,
      if TF Lite or TF.js models are mixed with other model types, or if
      eval_config is set and estimator models loaded with the eval tag are
      mixed with models that are not (mixed serving and eval models).
  """
  if materialize is None:
    # TODO(b/172969312): Once analysis table is supported, remove defaulting
    # to false unless 'analysis' is in disabled_outputs.
    materialize = False
  if slice_spec and eval_config:
    raise ValueError('slice_spec is deprecated, only use eval_config')
  if eval_config is not None:
    eval_config = _update_eval_config_with_defaults(eval_config,
                                                    eval_shared_model)

  if _is_legacy_eval(config_version, eval_shared_model, eval_config):
    # Backwards compatibility for previous add_metrics_callbacks implementation.
    # The deprecated slice_spec is converted into an EvalConfig so the slice
    # key extractor only ever has to deal with one configuration format.
    if not eval_config and slice_spec:
      eval_config = config_pb2.EvalConfig(
          slicing_specs=[s.to_proto() for s in slice_spec])
    return [
        custom_predict_extractor or legacy_predict_extractor.PredictExtractor(
            eval_shared_model, materialize=materialize),
        slice_key_extractor.SliceKeyExtractor(
            eval_config=eval_config, materialize=materialize)
    ]

  # Slicing extractors are appended after the model-specific extractors on
  # every non-legacy path below. The SQL slice key extractor (only when SQL
  # slices are configured) is placed before unbatching. The regular
  # SliceKeyExtractor is placed after it.
  slicing_extractors = []
  if _has_sql_slices(eval_config):
    slicing_extractors.append(
        sql_slice_key_extractor.SqlSliceKeyExtractor(eval_config))
  slicing_extractors.extend([
      unbatch_extractor.UnbatchExtractor(),
      slice_key_extractor.SliceKeyExtractor(
          eval_config=eval_config, materialize=materialize)
  ])

  if eval_shared_model:
    model_types = _model_types(eval_shared_model)
    eval_shared_models = model_util.verify_and_update_eval_shared_models(
        eval_shared_model)

    if (not model_types.issubset(constants.VALID_TF_MODEL_TYPES) and
        not custom_predict_extractor):
      raise NotImplementedError(
          'either a custom_predict_extractor must be used or model type must '
          'be one of: {}. evalconfig={}'.format(
              str(constants.VALID_TF_MODEL_TYPES), eval_config))

    if model_types == {constants.TF_LITE}:
      # TODO(b/163889779): Convert TFLite extractor to operate on batched
      # extracts. Then we can remove the input extractor.
      return [
          features_extractor.FeaturesExtractor(eval_config=eval_config),
          transformed_features_extractor.TransformedFeaturesExtractor(
              eval_config=eval_config,
              eval_shared_model=eval_shared_model,
              tensor_adapter_config=tensor_adapter_config),
          labels_extractor.LabelsExtractor(eval_config=eval_config),
          example_weights_extractor.ExampleWeightsExtractor(
              eval_config=eval_config),
          (custom_predict_extractor or
           tflite_predict_extractor.TFLitePredictExtractor(
               eval_config=eval_config, eval_shared_model=eval_shared_model))
      ] + slicing_extractors
    elif constants.TF_LITE in model_types:
      raise NotImplementedError(
          'support for mixing tf_lite and non-tf_lite models is not '
          'implemented: eval_config={}'.format(eval_config))

    if model_types == {constants.TF_JS}:
      return [
          features_extractor.FeaturesExtractor(eval_config=eval_config),
          labels_extractor.LabelsExtractor(eval_config=eval_config),
          example_weights_extractor.ExampleWeightsExtractor(
              eval_config=eval_config),
          (custom_predict_extractor or
           tfjs_predict_extractor.TFJSPredictExtractor(
               eval_config=eval_config, eval_shared_model=eval_shared_model))
      ] + slicing_extractors
    elif constants.TF_JS in model_types:
      raise NotImplementedError(
          'support for mixing tf_js and non-tf_js models is not '
          'implemented: eval_config={}'.format(eval_config))
    elif (eval_config and model_types == {constants.TF_ESTIMATOR} and
          all(eval_constants.EVAL_TAG in m.model_loader.tags
              for m in eval_shared_models)):
      # Every model is an estimator loaded with the eval tag, so the legacy
      # predict extractor is used.
      return [
          custom_predict_extractor or legacy_predict_extractor.PredictExtractor(
              eval_shared_model,
              materialize=materialize,
              eval_config=eval_config)
      ] + slicing_extractors
    elif (eval_config and constants.TF_ESTIMATOR in model_types and
          any(eval_constants.EVAL_TAG in m.model_loader.tags
              for m in eval_shared_models)):
      # Reached only when the all-eval-estimator case above did not match,
      # i.e. eval-tagged models are mixed with other models.
      raise NotImplementedError(
          'support for mixing eval and non-eval estimator models is not '
          'implemented: eval_config={}'.format(eval_config))
    else:
      extractors = [
          features_extractor.FeaturesExtractor(eval_config=eval_config)
      ]
      # The transformed-features extractor is only added when the built-in
      # PredictionsExtractor will be used, not when a custom one is supplied.
      if not custom_predict_extractor:
        extractors.append(
            transformed_features_extractor.TransformedFeaturesExtractor(
                eval_config=eval_config,
                eval_shared_model=eval_shared_model,
                tensor_adapter_config=tensor_adapter_config))
      extractors.extend([
          labels_extractor.LabelsExtractor(eval_config=eval_config),
          example_weights_extractor.ExampleWeightsExtractor(
              eval_config=eval_config),
          (custom_predict_extractor or
           predictions_extractor.PredictionsExtractor(
               eval_config=eval_config,
               eval_shared_model=eval_shared_model,
               tensor_adapter_config=tensor_adapter_config)),
      ])
      extractors.extend(slicing_extractors)
      return extractors
  else:
    # Model-agnostic evaluation: predictions are expected alongside the
    # features, so no model is invoked.
    # NOTE(review): custom_predict_extractor is not used on this path even if
    # provided. Confirm this is intended before relying on it.
    return [
        features_extractor.FeaturesExtractor(eval_config=eval_config),
        labels_extractor.LabelsExtractor(eval_config=eval_config),
        example_weights_extractor.ExampleWeightsExtractor(
            eval_config=eval_config),
        predictions_extractor.PredictionsExtractor(eval_config=eval_config)
    ] + slicing_extractors