Status PostProcessClassificationResult()

in tensorflow_serving/servables/tensorflow/classifier.cc [224:323]


// Converts the raw output tensors of a classification signature into
// ClassificationResult form.
//
// `output_tensor_names[i]` names `output_tensors[i]`; the two vectors must have
// the same length. The tensor whose name equals the signature output keyed by
// kClassifyOutputClasses supplies the class labels, and the one keyed by
// kClassifyOutputScores supplies the scores. Any other tensors are ignored, and
// either output (or both) may be absent.
//
// Validation (returns InvalidArgument before `result` is modified):
//   * classes, if present: rank 2, DT_STRING, dim_size(0) == num_examples.
//   * scores, if present:  rank 2, DT_FLOAT,  dim_size(0) == num_examples.
//   * if both are present, their dim_size(1) must agree.
//
// On success, appends `num_examples` Classifications to `result` (existing
// entries are kept). Each one holds dim_size(1) Class entries, whose label
// and/or score are taken from row i of the corresponding tensor. If neither
// tensor is present, each appended Classifications has no classes.
Status PostProcessClassificationResult(
    const SignatureDef& signature, int num_examples,
    const std::vector<string>& output_tensor_names,
    const std::vector<Tensor>& output_tensors, ClassificationResult* result) {
  if (output_tensors.size() != output_tensor_names.size()) {
    return errors::InvalidArgument(
        strings::StrCat("Expected ", output_tensor_names.size(),
                        " output tensor(s).  Got: ", output_tensors.size()));
  }

  // Resolve the tensor names the signature assigns to the classes and scores
  // outputs. A missing key leaves the corresponding name empty.
  auto classes_iter = signature.outputs().find(kClassifyOutputClasses);
  string classes_tensor_name;
  if (classes_iter != signature.outputs().end()) {
    classes_tensor_name = classes_iter->second.name();
  }
  auto scores_iter = signature.outputs().find(kClassifyOutputScores);
  string scores_tensor_name;
  if (scores_iter != signature.outputs().end()) {
    scores_tensor_name = scores_iter->second.name();
  }

  const Tensor* classes = nullptr;
  const Tensor* scores = nullptr;
  for (size_t i = 0; i < output_tensors.size(); ++i) {
    const string& name = output_tensor_names[i];
    // An empty signature name means that output is absent; it must never bind
    // to an output tensor that happens to have an empty name.
    if (!classes_tensor_name.empty() && name == classes_tensor_name) {
      classes = &output_tensors[i];
    } else if (!scores_tensor_name.empty() && name == scores_tensor_name) {
      scores = &output_tensors[i];
    }
  }

  // Validate classes output Tensor.
  if (classes) {
    if (classes->dims() != 2) {
      return errors::InvalidArgument(
          "Expected Tensor shape: [batch_size num_classes] but got ",
          classes->shape().DebugString());
    }
    if (classes->dtype() != DT_STRING) {
      return errors::InvalidArgument(
          "Expected classes Tensor of DT_STRING. Got: ",
          DataType_Name(classes->dtype()));
    }
    if (classes->dim_size(0) != num_examples) {
      return errors::InvalidArgument("Expected classes output batch size of ",
                                     num_examples,
                                     ". Got: ", classes->dim_size(0));
    }
  }
  // Validate scores output Tensor.
  if (scores) {
    if (scores->dims() != 2) {
      return errors::InvalidArgument(
          "Expected Tensor shape: [batch_size num_classes] but got ",
          scores->shape().DebugString());
    }
    if (scores->dtype() != DT_FLOAT) {
      return errors::InvalidArgument(
          "Expected scores Tensor of DT_FLOAT. Got: ",
          DataType_Name(scores->dtype()));
    }
    if (scores->dim_size(0) != num_examples) {
      return errors::InvalidArgument("Expected scores output batch size of ",
                                     num_examples,
                                     ". Got: ", scores->dim_size(0));
    }
  }

  // Extract the number of classes from either the class or score output
  // Tensor. dim_size() is 64-bit, so keep the full width here.
  int64_t num_classes = 0;
  if (classes && scores) {
    // If we have both Tensors they should agree in the second dimension.
    if (classes->dim_size(1) != scores->dim_size(1)) {
      return errors::InvalidArgument(
          "Tensors class and score should match in dim_size(1). Got ",
          classes->dim_size(1), " vs. ", scores->dim_size(1));
    }
    num_classes = classes->dim_size(1);
  } else if (classes) {
    num_classes = classes->dim_size(1);
  } else if (scores) {
    num_classes = scores->dim_size(1);
  }

  // Build the element views once, rather than re-creating (and re-checking) a
  // matrix view for every element. TensorFlow tensors are stored row-major, so
  // element (i, c) of a [num_examples, num_classes] tensor is at flat offset
  // i * num_classes + c.
  const tstring* class_data =
      classes ? classes->flat<tstring>().data() : nullptr;
  const float* score_data = scores ? scores->flat<float>().data() : nullptr;

  // Convert the output to ClassificationResult format.
  for (int i = 0; i < num_examples; ++i) {
    serving::Classifications* classifications = result->add_classifications();
    const int64_t row_offset = i * num_classes;
    for (int64_t c = 0; c < num_classes; ++c) {
      serving::Class* cl = classifications->add_classes();
      if (class_data != nullptr) {
        const tstring& class_tstr = class_data[row_offset + c];
        cl->set_label(class_tstr.data(), class_tstr.size());
      }
      if (score_data != nullptr) {
        cl->set_score(score_data[row_offset + c]);
      }
    }
  }
  return Status::OK();
}