in tensorflow_serving/batching/batching_session.cc [700:801]
void BatchingSession::ProcessBatch(
const TensorSignature& signature,
std::unique_ptr<Batch<BatchingSessionTask>> batch) {
// As a possible performance optimization, consider overlapping the tensor
// concatenation with waiting for the batch to close (i.e. do the
// concatenation incrementally as tasks stream into the batch).
batch->WaitUntilClosed();
if (batch->empty()) {
return;
}
const uint64_t dequeue_time_micros = EnvTime::NowMicros();
// Regardless of the outcome, we need to propagate the status to the
// individual tasks and signal that they are done. We use MakeCleanup() to
// ensure that this happens no matter how we exit the method below.
Status status;
auto finally = gtl::MakeCleanup([&status, &batch] {
for (int i = 0; i < batch->num_tasks(); ++i) {
BatchingSessionTask* task = batch->mutable_task(i);
if (task->is_partial) {
task->thread_safe_status->Update(status);
task->done_callback();
} else {
*batch->mutable_task(i)->status = status;
batch->mutable_task(i)->done->Notify();
}
}
});
// Make sure we have at least one task that hasn't exceeded its timeout from
// queue time alone, and find the latest task deadline which we'll use for the
// overall batch.
bool all_tasks_timeout_exceeded = true;
uint64_t batch_deadline_micros = 0;
for (int i = 0; i < batch->num_tasks(); ++i) {
const BatchingSessionTask& task = batch->task(i);
// If the caller doesn't populate RunOptions, the timeout is 0 by default.
// Interpret that as "no timeout" i.e. infinity.
const int64_t task_timeout_micros =
task.run_options.timeout_in_ms() <= 0
? INT_MAX
: task.run_options.timeout_in_ms() * 1000;
const uint64_t task_deadline_micros =
task.enqueue_time_micros + task_timeout_micros;
if (task_deadline_micros > dequeue_time_micros) {
all_tasks_timeout_exceeded = false;
if (task_deadline_micros > batch_deadline_micros) {
batch_deadline_micros = task_deadline_micros;
}
}
queuing_latency->GetCell(thread_pool_name_)
->Add(dequeue_time_micros - task.enqueue_time_micros);
}
if (all_tasks_timeout_exceeded) {
status = Status(error::RESOURCE_EXHAUSTED,
"Run() timeout exceeded while waiting in batching queue");
return;
}
RunOptions run_options = batch->task(0).run_options;
if (batch_deadline_micros == INT_MAX) {
run_options.set_timeout_in_ms(0);
} else {
run_options.set_timeout_in_ms(
(batch_deadline_micros - dequeue_time_micros) / 1000);
}
std::vector<std::pair<string, Tensor>> merged_inputs;
status = MergeInputTensors(signature, *batch, &merged_inputs);
if (!status.ok()) {
return;
}
absl::optional<thread::ThreadPoolOptions> thread_pool_options =
batch->task(0).thread_pool_options;
const std::vector<string> output_tensor_names(
signature.output_tensors.begin(), signature.output_tensors.end());
std::vector<Tensor> combined_outputs;
RunMetadata run_metadata;
// Because the wrapped session may not provide an implementation for
// thread_pool_options, we need to invoke different Run() functions depending
// on whether thread_pool_options is specified.
if (thread_pool_options) {
status = wrapped_->Run(run_options, merged_inputs, output_tensor_names,
{} /* target node names */, &combined_outputs,
&run_metadata, thread_pool_options.value());
} else {
status = wrapped_->Run(run_options, merged_inputs, output_tensor_names,
{} /* target node names */, &combined_outputs,
&run_metadata);
}
status.Update(SplitRunMetadata(&run_metadata, batch.get()));
if (!status.ok()) {
return;
}
status = SplitOutputTensors(signature, combined_outputs, batch.get());
}