in tfx_bsl/beam/run_inference.py [0:0]
def _get_operation_type(
    inference_spec_type: model_spec_pb2.InferenceSpecType) -> Text:
  """Determines the inference operation type for an inference spec.

  Remote inference always maps to prediction. For in-process inference the
  saved model's signature method names decide the operation: a single
  signature maps directly to classification, regression, or prediction;
  multiple signatures (multi-head) must all be classify/regress signatures
  and map to multi-inference.

  Args:
    inference_spec_type: The inference spec describing the model.

  Returns:
    One of the `_OperationType` values.

  Raises:
    ValueError: If the model has no valid signature, or a signature's
      method_name is unsupported for the single- or multi-head case.
  """
  if not _using_in_process_inference(inference_spec_type):
    # Remote inference supports predictions only.
    return _OperationType.PREDICTION

  saved_model_spec = inference_spec_type.saved_model_spec
  signatures = _get_signatures(
      saved_model_spec.model_path,
      saved_model_spec.signature_name,
      _get_tags(inference_spec_type))
  if not signatures:
    raise ValueError('Model does not have valid signature to use')

  if len(signatures) == 1:
    method_name = signatures[0].signature_def.method_name
    # Single-head model: method_name maps directly to an operation type.
    single_head_operations = {
        tf.saved_model.CLASSIFY_METHOD_NAME: _OperationType.CLASSIFICATION,
        tf.saved_model.REGRESS_METHOD_NAME: _OperationType.REGRESSION,
        tf.saved_model.PREDICT_METHOD_NAME: _OperationType.PREDICTION,
    }
    if method_name not in single_head_operations:
      raise ValueError('Unsupported signature method_name %s' % method_name)
    return single_head_operations[method_name]

  # Multi-head model: every signature must be classify or regress.
  multi_head_method_names = (tf.saved_model.CLASSIFY_METHOD_NAME,
                             tf.saved_model.REGRESS_METHOD_NAME)
  for signature in signatures:
    method_name = signature.signature_def.method_name
    if method_name not in multi_head_method_names:
      raise ValueError('Unsupported signature method_name for multi-head '
                       'model inference: %s' % method_name)
  return _OperationType.MULTI_INFERENCE