Status TensorFlowMultiInferenceRunner::Infer()

in tensorflow_serving/servables/tensorflow/multi_inference.cc [30:124]


// Evaluates every task of a MultiInferenceRequest with a single
// PerformOneShotTensorComputation call and appends one result per task to
// `response`, in request order.
//
// Validation performed before any computation:
//   - the request must contain at least one task;
//   - every task must name the same, non-empty model;
//   - each (defaulted) signature name may appear at most once and must exist
//     in the loaded MetaGraphDef;
//   - the method name must be the classify or regress method name.
// The input and output tensor names of all tasks are unioned so the model is
// run once; each task's result is then extracted from the shared outputs.
//
// Args:
//   run_options: forwarded unchanged to PerformOneShotTensorComputation.
//   request: the tasks to evaluate plus the shared input.
//   response: receives one result (with its model spec) per task.
//
// Returns InvalidArgument for an empty task list, an empty or mismatched model
// name, or a duplicate/unknown signature; Unimplemented for an unsupported
// method name; otherwise any error from pre-processing, the computation, or
// post-processing.
Status TensorFlowMultiInferenceRunner::Infer(
    const RunOptions& run_options, const MultiInferenceRequest& request,
    MultiInferenceResponse* response) {
  TRACELITERAL("TensorFlowMultiInferenceRunner::Infer");

  // With no tasks there is no model name to attribute the example count to and
  // no tensors to feed or fetch, so reject the request instead of running an
  // empty computation.
  if (request.tasks().empty()) {
    return errors::InvalidArgument(
        "MultiInferenceRequest must contain at least one task.");
  }

  const auto& signature_defs = meta_graph_def_->signature_def();
  // Signature resolved for each task during validation, indexed like
  // request.tasks(), so post-processing reuses it instead of repeating the
  // name defaulting and the map lookup. The map is not modified in this
  // function, so the stored iterators stay valid.
  std::vector<decltype(signature_defs.begin())> task_signatures;
  task_signatures.reserve(request.tasks_size());

  string model_name;
  std::set<string> signature_names;
  std::set<string> input_tensor_name_set;
  std::set<string> output_tensor_name_set;
  for (const auto& task : request.tasks()) {
    if (task.model_spec().name().empty()) {
      return errors::InvalidArgument(
          "Found ModelSpec with an empty model name.");
    }
    if (model_name.empty()) {
      model_name = task.model_spec().name();
    } else if (model_name != task.model_spec().name()) {
      return errors::InvalidArgument(
          "All ModelSpecs in a MultiInferenceRequest must access the same "
          "model name.");
    }

    const string signature_name = task.model_spec().signature_name().empty()
                                      ? kDefaultServingSignatureDefKey
                                      : task.model_spec().signature_name();

    // insert() reports whether the name was new, so a single call both records
    // the signature and detects a repeated one.
    if (!signature_names.insert(signature_name).second) {
      return errors::InvalidArgument(strings::StrCat(
          "Duplicate evaluation of signature: ", signature_name));
    }

    const auto iter = signature_defs.find(signature_name);
    if (iter == signature_defs.end()) {
      return errors::InvalidArgument(strings::StrCat(
          "Requested signature not found in model graph: ", signature_name));
    }
    task_signatures.push_back(iter);

    string input_name;
    std::vector<string> output_names;
    if (task.method_name() == kClassifyMethodName) {
      TF_RETURN_IF_ERROR(
          PreProcessClassification(iter->second, &input_name, &output_names));
    } else if (task.method_name() == kRegressMethodName) {
      TF_RETURN_IF_ERROR(
          PreProcessRegression(iter->second, &input_name, &output_names));
    } else {
      return errors::Unimplemented("Unsupported signature method_name: ",
                                   task.method_name());
    }
    input_tensor_name_set.insert(input_name);
    output_tensor_name_set.insert(output_names.begin(), output_names.end());
  }

  const std::vector<string> output_tensor_names(output_tensor_name_set.begin(),
                                                output_tensor_name_set.end());

  std::vector<Tensor> outputs;
  int num_examples = 0;
  TF_RETURN_IF_ERROR(PerformOneShotTensorComputation(
      run_options, request.input(), input_tensor_name_set, output_tensor_names,
      session_, &outputs, &num_examples, thread_pool_options_));
  RecordRequestExampleCount(model_name, num_examples);

  TRACELITERAL("PostProcessResults");
  for (int i = 0; i < request.tasks_size(); ++i) {
    const auto& task = request.tasks(i);
    const auto& signature_def = task_signatures[i]->second;
    auto* const result = response->add_results();
    if (task.method_name() == kClassifyMethodName) {
      TF_RETURN_IF_ERROR(PostProcessClassificationResult(
          signature_def, num_examples, output_tensor_names, outputs,
          result->mutable_classification_result()));
    } else if (task.method_name() == kRegressMethodName) {
      TF_RETURN_IF_ERROR(PostProcessRegressionResult(
          signature_def, num_examples, output_tensor_names, outputs,
          result->mutable_regression_result()));
    } else {
      // Defensive only: the validation loop above already returned
      // Unimplemented for any other method name.
      return errors::InvalidArgument("Unrecognized signature method_name: ",
                                     task.method_name());
    }
    // Passes the task's signature name exactly as given in the request (it
    // may be empty).
    MakeModelSpec(task.model_spec().name(), task.model_spec().signature_name(),
                  servable_version_, result->mutable_model_spec());
  }
  return Status::OK();
}