in tensorflow_serving/batching/batching_session.cc [363:445]
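// Routes a Run() call to the batch scheduler whose tensor signature matches
// the request. If no scheduler matches, either creates one lazily via the
// default scheduler creator or, failing that, runs the request in-line on
// the wrapped session.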
Status BatchingSession::InternalRun(
const RunOptions& run_options,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
RunMetadata* run_metadata,
absl::optional<thread::ThreadPoolOptions> thread_pool_options) {
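  // Batching operates on whole Run() calls; target nodes cannot be carried
  // through the batching pipeline, so reject them up front.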
if (!target_node_names.empty()) {
return errors::PermissionDenied(
"BatchingSession does not support target nodes");
}
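  // Record a root profiler event tagged with the batching thread pool name
  // so batched work can be attributed to this session in traces.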
profiler::TraceMe trace_me([this] {
return profiler::TraceMeEncode(
"BatchingSessionRun",
{{"thread_pool_name", thread_pool_name_}, {"_r", 1} /*root_event*/});
});
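  // The input and output tensor names form the signature that keys the
  // batch schedulers; a request is only batched with others that have the
  // same signature.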
const TensorSignature signature =
TensorSignatureFromRunArgs(inputs, output_tensor_names);
auto batch_scheduler_it = batch_schedulers_.find(signature);
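  // No scheduler was declared for this signature: either create one lazily
  // via the default scheduler creator, or bypass batching and run in-line.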
if (batch_scheduler_it == batch_schedulers_.end()) {
if (default_scheduler_creator_.has_value()) {
absl::MutexLock l(&mu_);
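      // Double-checked lookup: another thread may have created a scheduler
      // for this signature while we were waiting on the lock.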
batch_scheduler_it = custom_signature_batch_schedulers_.find(signature);
if (batch_scheduler_it == custom_signature_batch_schedulers_.end()) {
std::unique_ptr<BatchScheduler<BatchingSessionTask>> batch_scheduler;
TF_RETURN_IF_ERROR(default_scheduler_creator_.value()(
[&, signature](std::unique_ptr<Batch<BatchingSessionTask>> batch) {
ProcessBatch(signature, std::move(batch));
},
&batch_scheduler));
custom_signature_batch_schedulers_[signature] =
std::move(batch_scheduler);
batch_scheduler_it = custom_signature_batch_schedulers_.find(signature);
}
} else {
// We have a Run() call that doesn't match one of our batching signatures.
// Run it in-line.
      LOG_EVERY_N_SEC(WARNING, 120)
          << "Request doesn't match any declared signature and no default "
             "scheduler creator was specified. Bypassing batcher. "
             "Request signature is: "
          << TensorSignatureDebugString(signature);
// Because the wrapped session may not provide an implementation for
// thread_pool_options, we need to invoke different Run() functions
// depending on whether thread_pool_options is specified.
if (thread_pool_options) {
return wrapped_->Run(run_options, inputs, output_tensor_names,
target_node_names, outputs, run_metadata,
thread_pool_options.value());
} else {
return wrapped_->Run(run_options, inputs, output_tensor_names,
target_node_names, outputs, run_metadata);
}
}
}
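  // A scheduler exists for this signature. Package the request into a
  // BatchingSessionTask whose pointers refer to this stack frame; the
  // scheduler thread fills them in and signals `done` when finished.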
BatchScheduler<BatchingSessionTask>* batch_scheduler =
batch_scheduler_it->second.get();
outputs->clear();
Notification done;
Status status;
  auto task = absl::make_unique<BatchingSessionTask>();
task->enqueue_time_micros = EnvTime::NowMicros();
task->run_options = run_options;
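  // The task's size is the request's zeroth (batch) dimension, which the
  // scheduler uses to pack tasks into batches without exceeding the
  // configured maximum batch size.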
TF_RETURN_IF_ERROR(ComputeInputSize(inputs, &task->zeroth_dim_size));
task->inputs = &inputs;
task->output_tensor_names = &output_tensor_names;
task->done = &done;
task->status = &status;
task->outputs = outputs;
task->run_metadata = run_metadata;
task->thread_pool_options = thread_pool_options;
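  // The shared fields below support splitting a large task into subtasks
  // that are processed across multiple batches and then merged.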
task->thread_safe_status = std::make_shared<ThreadSafeStatus>();
task->shared_outputs = std::make_shared<std::vector<std::vector<Tensor>>>();
task->split_run_metadatas = absl::make_unique<std::vector<RunMetadata>>();
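  // On success, Schedule() takes ownership of the task; ProcessBatch() runs
  // it as part of a batch, fills in `outputs`, `run_metadata`, and `status`,
  // and signals `done`, which we block on below.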
TF_RETURN_IF_ERROR(batch_scheduler->Schedule(&task));
done.WaitForNotification();
return status;
}
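
For context, here is a minimal sketch of how a caller typically reaches
InternalRun(), assuming the CreateBasicBatchingSession() helper declared in
batching_session.h. The option values are illustrative, and
`wrapped_session` / `input_tensor` are hypothetical placeholders, not part
of the file above:

#include "tensorflow_serving/batching/batching_session.h"

// Wraps `wrapped_session` in a BatchingSession that batches Run() calls
// matching the {"x"} -> {"y"} signature, then issues one such call.
// Sketch only; the scheduler knobs are illustrative.
tensorflow::Status RunBatchedExample(
    std::unique_ptr<tensorflow::Session> wrapped_session,
    const tensorflow::Tensor& input_tensor,
    std::vector<tensorflow::Tensor>* outputs) {
  using tensorflow::serving::BasicBatchScheduler;
  using tensorflow::serving::BatchingSessionTask;

  // Batch up to 32 requests, waiting at most 1 ms for a batch to fill
  // before executing a partial batch.
  BasicBatchScheduler<BatchingSessionTask>::Options schedule_options;
  schedule_options.max_batch_size = 32;
  schedule_options.batch_timeout_micros = 1000;

  tensorflow::serving::BatchingSessionOptions batching_session_options;

  // Requests whose input/output tensor names match this signature are
  // batched together; anything else takes the bypass path in InternalRun().
  tensorflow::serving::TensorSignature signature = {
      /*input_tensors=*/{"x"}, /*output_tensors=*/{"y"}};

  std::unique_ptr<tensorflow::Session> batching_session;
  TF_RETURN_IF_ERROR(tensorflow::serving::CreateBasicBatchingSession(
      schedule_options, batching_session_options, signature,
      std::move(wrapped_session), &batching_session));

  // This Run() matches `signature`, so it is enqueued onto the batch
  // scheduler rather than executed immediately.
  return batching_session->Run({{"x", input_tensor}}, {"y"},
                               /*target_node_names=*/{}, outputs);
}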