in tfx_bsl/beam/run_inference.py [0:0]
def _post_process(
    self,
    examples: List[Union[tf.train.Example, bytes]],
    serialized_examples: List[bytes],
    outputs: Mapping[Text, np.ndarray]
) -> List[prediction_log_pb2.PredictionLog]:
  """Builds one multi-inference PredictionLog per input example.

  Args:
    examples: Input examples; each is either a `tf.train.Example` or the
      serialized bytes of one.
    serialized_examples: Unused; discarded on entry.
    outputs: Model outputs (string keys to numpy arrays), forwarded unchanged
      to the classify/regress post-processing helpers.

  Returns:
    One `PredictionLog` per entry of `examples`, in the same order. Each log's
    `multi_inference_log` carries the example as the request input and one
    inference result per entry of `self._signatures`, tagged with that
    signature's name.

  Raises:
    ValueError: If a signature's method is neither classify nor regress.
    RuntimeError: If a response's result count differs from the number of
      signatures.
  """
  del serialized_examples
  classify_method = tf.saved_model.CLASSIFY_METHOD_NAME
  regress_method = tf.saved_model.REGRESS_METHOD_NAME

  def _unsupported(method_name):
    # Shared constructor so both dispatch passes emit the identical message.
    return ValueError('Signature method %s is not supported for '
                      'multi inference' % method_name)

  # Pass 1: post-process the batch once per signature, remembering the most
  # recent result list for each supported method name.
  # NOTE(review): every classify signature is post-processed with the same
  # arguments (likewise every regress signature), so all signatures of one
  # method receive identical per-example results -- confirm this is intended
  # for models exposing several signatures of the same method.
  per_method = {}
  for sig in self._signatures:
    method = sig.signature_def.method_name
    if method == classify_method:
      per_method[method] = _post_process_classify(
          self._io_tensor_spec.output_alias_tensor_names, examples, outputs)
    elif method == regress_method:
      per_method[method] = _post_process_regress(examples, outputs)
    else:
      raise _unsupported(method)

  # Pass 2: assemble a PredictionLog for each example.
  logs = []
  for idx, example in enumerate(examples):
    log = prediction_log_pb2.PredictionLog()
    multi_log = log.multi_inference_log
    request_example = multi_log.request.input.example_list.examples.add()
    if isinstance(example, bytes):
      request_example.ParseFromString(example)
    else:
      request_example.CopyFrom(example)

    response = multi_log.response
    for sig in self._signatures:
      method = sig.signature_def.method_name
      result = response.results.add()
      method_results = per_method.get(method)
      # A missing or empty result list for the method falls through to the
      # unsupported-method error, exactly like an unknown method would.
      if method == classify_method and method_results:
        result.classification_result.classifications.add().CopyFrom(
            method_results[idx])
      elif method == regress_method and method_results:
        result.regression_result.regressions.add().CopyFrom(
            method_results[idx])
      else:
        raise _unsupported(method)
      result.model_spec.signature_name = sig.name

    # Defensive check: the loop above appends exactly one result per
    # signature (or raises), so this is not expected to trigger.
    if len(response.results) != len(self._signatures):
      raise RuntimeError('Multi inference response result length does not '
                         'match the number of signatures')
    logs.append(log)
  return logs