Status TfLiteSession::Create()

in tensorflow_serving/servables/tensorflow/tflite_session.cc [413:530]


Status TfLiteSession::Create(string&& buffer, const SessionOptions& options,
                             int num_pools, int num_interpreters_per_pool,
                             std::unique_ptr<TfLiteSession>* tflite_session,
                             ::google::protobuf::Map<string, SignatureDef>* signatures) {
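  // BuildFromModel does not copy the flatbuffer: the resulting model points
  // into `buffer`, which is why `buffer` is moved into the session below to
  // keep the allocation alive.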
  auto model = tflite::FlatBufferModel::BuildFromModel(
      flatbuffers::GetRoot<tflite::Model>(buffer.data()));
  if (model == nullptr) {
    return errors::InvalidArgument("Cannot build FlatBufferModel from buffer.");
  }

  // Register the builtin kernels plus the custom ParseExample op, which
  // models that parse serialized tf.Example protos depend on.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  tflite::ops::custom::AddParseExampleOp(&resolver);

  // This interpreter is short-lived: it exists only to read tensor metadata
  // below, while request handling uses the interpreter pool created later.
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
    return errors::Internal("Cannot build Interpreter from buffer.");
  }

  // TensorInfoMap maps tensor name -> (TensorInfo, tensor index); the boolean
  // argument selects input (true) vs. output (false) tensors.
  TensorInfoMap inputs;
  TF_RETURN_IF_ERROR(GetTensorInfoMap(interpreter.get(), true, &inputs));
  TensorInfoMap outputs;
  TF_RETURN_IF_ERROR(GetTensorInfoMap(interpreter.get(), false, &outputs));

  // Map of TFLite tensor name -> tensor index. These maps feed
  // FixTfLiteTensorName() below and are then moved into the session for
  // request-time tensor lookup.
  std::map<string, int> input_tensor_to_index;
  std::map<string, int> output_tensor_to_index;
  for (const auto& info : inputs) {
    const string& tflite_tensor_name = info.first;
    input_tensor_to_index[tflite_tensor_name] = info.second.second;
  }
  for (const auto& info : outputs) {
    const string& tflite_tensor_name = info.first;
    output_tensor_to_index[tflite_tensor_name] = info.second.second;
  }

  // Attempt to read signature defs from the model file
  std::map<string, SignatureDef> signature_defs;
  const auto status =
      tflite::GetSignatureDefMap(model->GetModel(), &signature_defs);
  if (status != Status::OK()) {
    return errors::InvalidArgument(
        "Invalid SignatureDefs found in TfLite model: ",
        status.error_message());
  }
  const bool has_lite_signature_def = !signature_defs.empty();

  signatures->clear();
  if (has_lite_signature_def) {
    // Check that input/output tensors in the signature defs refer to existing
    // tensors.
    // If not found, try to match with legacy TFLite name (without suffix).
    for (const auto& signature_item : signature_defs) {
      SignatureDef* tflite_signature = &(*signatures)[signature_item.first];
      tflite_signature->CopyFrom(signature_item.second);
      for (auto& input : *tflite_signature->mutable_inputs()) {
        TensorInfo* tensor_info = &input.second;
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            FixTfLiteTensorName(input_tensor_to_index,
                                *tensor_info->mutable_name()),
            "Signature input ", input.first, " references an unknown tensor");
      }
      for (auto& output : *tflite_signature->mutable_outputs()) {
        TensorInfo* tensor_info = &output.second;
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            FixTfLiteTensorName(output_tensor_to_index,
                                *tensor_info->mutable_name()),
            "Signature output ", output.first, " references an unknown tensor");
      }
    }
  } else {
    // Build a mock signature from the input/output tensors of the model.
    // TODO(b/169239308)
    LOG(WARNING) << "No signature def found in TFLite model. Generating one.";
    SignatureDef* sigdef = &(*signatures)[kDefaultServingSignatureDefKey];
    for (const auto& info : inputs) {
      // Legacy TFLite names drop the TF ":N" output suffix (see above).
      string tflite_tensor_name = TfToTfLiteLegacyTensorName(info.first);
      (*sigdef->mutable_inputs())[tflite_tensor_name] = info.second.first;
    }
    for (const auto& info : outputs) {
      string tflite_tensor_name = TfToTfLiteLegacyTensorName(info.first);
      (*sigdef->mutable_outputs())[tflite_tensor_name] = info.second.first;
    }
    sigdef->set_method_name(kPredictMethodName);
  }

  // Run at least one interpreter even if the caller passed num_pools <= 0.
  const int num_interpreters = std::max(1, num_pools);
  // A batch size declared in the model (if > 1) bounds the scheduler below.
  const int model_batch_size = GetModelBatchSize(model->GetModel());

  std::unique_ptr<internal::TfLiteInterpreterPool> interpreter_pool;
  TF_RETURN_IF_ERROR(
      internal::TfLiteInterpreterPool::CreateTfLiteInterpreterPool(
          model.get(), options, num_interpreters, interpreter_pool));

  tflite_session->reset(new TfLiteSession(
      std::move(input_tensor_to_index), std::move(output_tensor_to_index),
      std::move(buffer), std::move(model), std::move(interpreter_pool)));

  if (num_interpreters_per_pool > 1) {
    // Ceiling division: spread the initial batch size evenly across the
    // interpreters in the pool.
    const int default_allowed_batch =
        (internal::kInitialBatchSize + num_interpreters_per_pool - 1) /
        num_interpreters_per_pool;
    // Prefer a batch size baked into the model, when one is declared.
    const int min_allowed_batch =
        model_batch_size > 1 ? model_batch_size : default_allowed_batch;
    const int max_enqueued_batches = num_interpreters * 100;
    BasicBatchScheduler<TfLiteBatchTask>::Options scheduler_options;
    scheduler_options.num_batch_threads = num_interpreters;
    scheduler_options.max_batch_size = internal::kInitialBatchSize;
    scheduler_options.enable_large_batch_splitting = true;
    scheduler_options.max_execution_batch_size = min_allowed_batch;
    scheduler_options.max_enqueued_batches = max_enqueued_batches;
    scheduler_options.split_input_task_func = SplitTfLiteInputTask;
    TF_RETURN_IF_ERROR(
        (*tflite_session)
            ->SetScheduler(&TfLiteSession::CreateDefaultBasicBatchScheduler,
                           scheduler_options));
  }
  return Status::OK();
}
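
For context, a minimal caller sketch (hypothetical; not part of tflite_session.cc) that reads a serialized model from disk and builds a session. The LoadTfLiteSession helper name and the include paths are assumptions; with num_interpreters_per_pool == 1 the batch-scheduler branch above is skipped.

#include <fstream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow_serving/servables/tensorflow/tflite_session.h"

namespace tensorflow {
namespace serving {

// Hypothetical helper: load a .tflite flatbuffer and create a TfLiteSession.
Status LoadTfLiteSession(
    const string& model_path, std::unique_ptr<TfLiteSession>* session,
    ::google::protobuf::Map<string, SignatureDef>* signatures) {
  std::ifstream file(model_path, std::ios::binary);
  if (!file) {
    return errors::NotFound("Cannot open TFLite model at: ", model_path);
  }
  std::ostringstream contents;
  contents << file.rdbuf();
  string buffer = contents.str();

  // Default options; a single pool with one interpreter keeps the
  // num_interpreters_per_pool > 1 batching branch disabled.
  SessionOptions options;
  return TfLiteSession::Create(std::move(buffer), options,
                               /*num_pools=*/1,
                               /*num_interpreters_per_pool=*/1, session,
                               signatures);
}

}  // namespace serving
}  // namespace tensorflow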