in tensorflow_serving/servables/tensorflow/tflite_session.cc [413:530]
// Creates a TfLiteSession from a serialized TFLite flatbuffer.
//
// Args:
//   buffer: serialized TFLite model, consumed (moved into the session).
//     NOTE(review): `model` below is built directly over buffer.data(), so
//     the session presumably keeps `buffer` alive for the model's lifetime —
//     the two are moved into the session together at the end.
//   options: session options forwarded to the interpreter pool.
//   num_pools: number of interpreter pools to create (clamped to >= 1).
//   num_interpreters_per_pool: when > 1, a batch scheduler is configured to
//     split work across interpreters.
//   tflite_session: output; receives the constructed session.
//   signatures: output; cleared and filled with the model's SignatureDefs,
//     or with a synthesized default signature if the model has none.
//
// Returns OK on success; InvalidArgument for a malformed flatbuffer or bad
// SignatureDefs; Internal if an interpreter cannot be built.
Status TfLiteSession::Create(string&& buffer, const SessionOptions& options,
int num_pools, int num_interpreters_per_pool,
std::unique_ptr<TfLiteSession>* tflite_session,
::google::protobuf::Map<string, SignatureDef>* signatures) {
// Zero-copy parse: the FlatBufferModel references `buffer`'s bytes directly
// and does not own them.
auto model = tflite::FlatBufferModel::BuildFromModel(
flatbuffers::GetRoot<tflite::Model>(buffer.data()));
if (model == nullptr) {
return errors::InvalidArgument("Cannot build FlatBufferModel from buffer.");
}
// Resolver covers builtin ops plus the custom ParseExample op used by
// TF-exported models.
tflite::ops::builtin::BuiltinOpResolver resolver;
tflite::ops::custom::AddParseExampleOp(&resolver);
// This interpreter is only used to introspect input/output tensor metadata;
// serving uses the interpreter pool created further below.
std::unique_ptr<tflite::Interpreter> interpreter;
if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
return errors::Internal("Cannot build Interpreter from buffer.");
}
TensorInfoMap inputs;
TF_RETURN_IF_ERROR(GetTensorInfoMap(interpreter.get(), true, &inputs));
TensorInfoMap outputs;
TF_RETURN_IF_ERROR(GetTensorInfoMap(interpreter.get(), false, &outputs));
// Map of TFLite tensor name -> tensor index
// (info.second is a pair; .second holds the tensor index, .first the
// TensorInfo proto — see the mock-signature branch below.)
std::map<string, int> input_tensor_to_index;
std::map<string, int> output_tensor_to_index;
for (const auto& info : inputs) {
const string& tflite_tensor_name = info.first;
input_tensor_to_index[tflite_tensor_name] = info.second.second;
}
for (const auto& info : outputs) {
const string& tflite_tensor_name = info.first;
output_tensor_to_index[tflite_tensor_name] = info.second.second;
}
// Attempt to read signature defs from the model file
std::map<string, SignatureDef> signature_defs;
const auto status =
tflite::GetSignatureDefMap(model->GetModel(), &signature_defs);
if (status != Status::OK()) {
return errors::InvalidArgument(
"Invalid SignatureDefs found in TfLite model: ",
status.error_message());
}
const bool has_lite_signature_def = !signature_defs.empty();
signatures->clear();
if (has_lite_signature_def) {
// Check that input/output tensors in the signature defs refer to existing
// tensors.
// If not found, try to match with legacy TFLite name (without suffix).
for (const auto& signature_item : signature_defs) {
SignatureDef* tflite_signature = &(*signatures)[signature_item.first];
tflite_signature->CopyFrom(signature_item.second);
for (auto& input : *tflite_signature->mutable_inputs()) {
TensorInfo* tensor_info = &input.second;
// FixTfLiteTensorName rewrites the name in place to one present in the
// index map; errors mean the signature references a missing tensor.
TF_RETURN_WITH_CONTEXT_IF_ERROR(
FixTfLiteTensorName(input_tensor_to_index,
*tensor_info->mutable_name()),
"Signature input ", input.first, " references an unknown tensor");
}
for (auto& output : *tflite_signature->mutable_outputs()) {
TensorInfo* tensor_info = &output.second;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
FixTfLiteTensorName(output_tensor_to_index,
*tensor_info->mutable_name()),
"Signature output ", output.first, " references an unknown tensor");
}
}
} else {
// Build a mock signature from the input/output tensors of the model.
// TODO(b/169239308)
LOG(WARNING) << "No signature def found in TFLite model. Generating one.";
SignatureDef* sigdef = &(*signatures)[kDefaultServingSignatureDefKey];
for (const auto& info : inputs) {
string tflite_tensor_name = TfToTfLiteLegacyTensorName(info.first);
(*sigdef->mutable_inputs())[tflite_tensor_name] = info.second.first;
}
for (const auto& info : outputs) {
string tflite_tensor_name = TfToTfLiteLegacyTensorName(info.first);
(*sigdef->mutable_outputs())[tflite_tensor_name] = info.second.first;
}
sigdef->set_method_name(kPredictMethodName);
}
// At least one interpreter pool, even if the caller passed 0 or negative.
const int num_interpreters = std::max(1, num_pools);
const int model_batch_size = GetModelBatchSize(model->GetModel());
std::unique_ptr<internal::TfLiteInterpreterPool> interpreter_pool;
TF_RETURN_IF_ERROR(
internal::TfLiteInterpreterPool::CreateTfLiteInterpreterPool(
model.get(), options, num_interpreters, interpreter_pool));
// Ownership of buffer, model, and the pool transfers to the session.
// NOTE(review): moving the string relocates ownership of its heap storage
// without invalidating model's pointer into it (assumes buffer exceeds SSO
// size, which holds for any realistic model).
tflite_session->reset(new TfLiteSession(
std::move(input_tensor_to_index), std::move(output_tensor_to_index),
std::move(buffer), std::move(model), std::move(interpreter_pool)));
// Batching is only configured when multiple interpreters share a pool;
// otherwise requests run unbatched.
if (num_interpreters_per_pool > 1) {
// Spread an initial batch evenly across the pool's interpreters
// (rounding up).
const int default_allowed_batch =
(internal::kInitialBatchSize + num_interpreters_per_pool - 1) /
num_interpreters_per_pool;
// Respect a fixed batch dimension baked into the model, if any (> 1).
const int min_allowed_batch =
model_batch_size > 1 ? model_batch_size : default_allowed_batch;
const int max_enqueued_batches = num_interpreters * 100;
BasicBatchScheduler<TfLiteBatchTask>::Options scheduler_options;
scheduler_options.num_batch_threads = num_interpreters;
scheduler_options.max_batch_size = internal::kInitialBatchSize;
// Large incoming batches are split so each execution stays within
// max_execution_batch_size.
scheduler_options.enable_large_batch_splitting = true;
scheduler_options.max_execution_batch_size = min_allowed_batch;
scheduler_options.max_enqueued_batches = max_enqueued_batches;
scheduler_options.split_input_task_func = SplitTfLiteInputTask;
TF_RETURN_IF_ERROR(
(*tflite_session)
->SetScheduler(&TfLiteSession::CreateDefaultBasicBatchScheduler,
scheduler_options));
}
return Status::OK();
}