def _post_process()

in tfx_bsl/beam/run_inference.py [0:0]
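
Converts a batch of multi-head saved model outputs into one PredictionLog per
input example. The output tensors are first post-processed per signature
(classify and/or regress); each resulting PredictionLog then records the input
example in its multi_inference_log request and one InferenceResult per
signature in the response.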


  def _post_process(
      self,
      examples: List[Union[tf.train.Example, bytes]],
      serialized_examples: List[bytes],
      outputs: Mapping[Text, np.ndarray]
      ) -> List[prediction_log_pb2.PredictionLog]:
    del serialized_examples  # Unused by multi-inference post-processing.
    classifications = None
    regressions = None
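    # First pass: run each signature's batch-level post-processing over the
    # whole batch of examples.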
    for signature in self._signatures:
      signature_def = signature.signature_def
      if signature_def.method_name == tf.saved_model.CLASSIFY_METHOD_NAME:
        classifications = _post_process_classify(
            self._io_tensor_spec.output_alias_tensor_names, examples, outputs)
      elif signature_def.method_name == tf.saved_model.REGRESS_METHOD_NAME:
        regressions = _post_process_regress(examples, outputs)
      else:
        raise ValueError('Signature method %s is not supported for '
                         'multi inference' % signature_def.method_name)
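    # Second pass: build one PredictionLog per input example, attaching that
    # example's result from every signature.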
    result = []
    for i, example in enumerate(examples):
      prediction_log = prediction_log_pb2.PredictionLog()
      input_example = (prediction_log.multi_inference_log.request.input
                       .example_list.examples.add())
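      # The input may be a serialized example (bytes) or an already parsed
      # tf.train.Example proto.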
      (input_example.ParseFromString
       if isinstance(example, bytes)
       else input_example.CopyFrom)(example)
      response = prediction_log.multi_inference_log.response
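      # Record one InferenceResult per signature, in signature order.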
      for signature in self._signatures:
        signature_def = signature.signature_def
        inference_result = response.results.add()
        if (signature_def.method_name == tf.saved_model.CLASSIFY_METHOD_NAME and
            classifications):
          inference_result.classification_result.classifications.add().CopyFrom(
              classifications[i])
        elif (
            signature_def.method_name == tf.saved_model.REGRESS_METHOD_NAME and
            regressions):
          inference_result.regression_result.regressions.add().CopyFrom(
              regressions[i])
        else:
          raise ValueError('Signature method %s is not supported for '
                           'multi inference' % signature_def.method_name)
        inference_result.model_spec.signature_name = signature.name
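      # Every signature must have contributed exactly one result for this
      # example.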
      if len(response.results) != len(self._signatures):
        raise RuntimeError('Multi inference response result length does not '
                           'match the number of signatures')
      result.append(prediction_log)
    return result
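
_post_process_classify and _post_process_regress are defined elsewhere in
run_inference.py and are not shown here. Based on how their return values are
consumed above, each is assumed to yield one result proto per input example
(classification_pb2.Classifications and regression_pb2.Regression,
respectively). The sketch below only illustrates that assumed shape; the
function names, signatures, and tensor aliases ('outputs', 'classes',
'scores') are hypothetical and do not reflect the library's actual
implementation.


from typing import List, Mapping

import numpy as np
from tensorflow_serving.apis import classification_pb2, regression_pb2


def _sketch_post_process_regress(
    examples: List[bytes],
    outputs: Mapping[str, np.ndarray]) -> List[regression_pb2.Regression]:
  # Hypothetical output alias: one scalar regression value per example.
  values = outputs['outputs']
  if len(values) != len(examples):
    raise ValueError('Expected one regression value per input example.')
  return [regression_pb2.Regression(value=float(v)) for v in values]


def _sketch_post_process_classify(
    examples: List[bytes],
    outputs: Mapping[str, np.ndarray]
    ) -> List[classification_pb2.Classifications]:
  # Hypothetical output aliases: per-example rows of labels and scores.
  labels, scores = outputs['classes'], outputs['scores']
  results = []
  for example_labels, example_scores in zip(labels, scores):
    classifications = classification_pb2.Classifications()
    for label, score in zip(example_labels, example_scores):
      cls = classifications.classes.add()
      cls.label = label.decode('utf-8') if isinstance(label, bytes) else str(label)
      cls.score = float(score)
    results.append(classifications)
  return results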