Status BatchingSession::InternalRun()

in tensorflow_serving/batching/batching_session.cc [363:445]


// Runs a single Session::Run()-style request through the batching layer.
//
// The request's signature (derived from `inputs` and `output_tensor_names`)
// selects a batch scheduler:
//   * the scheduler declared for that signature in `batch_schedulers_`, or
//   * if `default_scheduler_creator_` is set, a scheduler created lazily (once
//     per signature, while holding `mu_`) and cached in
//     `custom_signature_batch_schedulers_`, or
//   * otherwise none, in which case the request bypasses batching and is run
//     directly on `wrapped_` (with `thread_pool_options` only if provided).
// When a scheduler is found, `outputs` is cleared, the request is enqueued as
// a BatchingSessionTask, and this call blocks until the task's `done`
// notification fires; the status recorded for the task is then returned.
//
// Returns PermissionDenied if `target_node_names` is non-empty, Internal if
// the default scheduler creator reports success without producing a
// scheduler, or any error from the scheduler creator, ComputeInputSize, or
// Schedule.
Status BatchingSession::InternalRun(
    const RunOptions& run_options,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& output_tensor_names,
    const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
    RunMetadata* run_metadata,
    absl::optional<thread::ThreadPoolOptions> thread_pool_options) {
  if (!target_node_names.empty()) {
    return errors::PermissionDenied(
        "BatchingSession does not support target nodes");
  }

  profiler::TraceMe trace_me([this] {
    return profiler::TraceMeEncode(
        "BatchingSessionRun",
        {{"thread_pool_name", thread_pool_name_}, {"_r", 1} /*root_event*/});
  });
  const TensorSignature signature =
      TensorSignatureFromRunArgs(inputs, output_tensor_names);

  // Resolve the scheduler to a plain pointer rather than carrying a map
  // iterator past the critical section below: once `mu_` is released another
  // thread may insert into `custom_signature_batch_schedulers_`, which (for a
  // hash map) can rehash and invalidate iterators. The scheduler object owned
  // by the unique_ptr keeps its address as long as the entry is not erased or
  // replaced.
  BatchScheduler<BatchingSessionTask>* batch_scheduler = nullptr;
  const auto declared_it = batch_schedulers_.find(signature);
  if (declared_it != batch_schedulers_.end()) {
    batch_scheduler = declared_it->second.get();
  } else if (default_scheduler_creator_.has_value()) {
    absl::MutexLock l(&mu_);
    const auto custom_it = custom_signature_batch_schedulers_.find(signature);
    if (custom_it != custom_signature_batch_schedulers_.end()) {
      batch_scheduler = custom_it->second.get();
    } else {
      // First request with this signature: create its scheduler while still
      // holding `mu_` so concurrent callers cannot create duplicates.
      std::unique_ptr<BatchScheduler<BatchingSessionTask>> new_scheduler;
      TF_RETURN_IF_ERROR(default_scheduler_creator_.value()(
          // The callback is presumably retained by the scheduler beyond this
          // call, so capture only `this` and a copy of `signature` -- never
          // locals by reference.
          [this, signature](std::unique_ptr<Batch<BatchingSessionTask>> batch) {
            ProcessBatch(signature, std::move(batch));
          },
          &new_scheduler));
      if (new_scheduler == nullptr) {
        // Guard against caching a null scheduler and dereferencing it below.
        return errors::Internal(
            "Default scheduler creator succeeded but produced no scheduler");
      }
      batch_scheduler = new_scheduler.get();
      custom_signature_batch_schedulers_[signature] = std::move(new_scheduler);
    }
  } else {
    // We have a Run() call that doesn't match one of our batching signatures.
    // Run it in-line.
    LOG_EVERY_N_SEC(WARNING, 120)
        << "Request doesn't match any declared signature and no default "
           "scheduler creator specified. Bypassing "
           "batcher. Request signature is: "
        << TensorSignatureDebugString(signature);

    // Because the wrapped session may not provide an implementation for
    // thread_pool_options, we need to invoke different Run() functions
    // depending on whether thread_pool_options is specified.
    if (thread_pool_options) {
      return wrapped_->Run(run_options, inputs, output_tensor_names,
                           target_node_names, outputs, run_metadata,
                           thread_pool_options.value());
    }
    return wrapped_->Run(run_options, inputs, output_tensor_names,
                         target_node_names, outputs, run_metadata);
  }

  outputs->clear();

  // `done`, `status`, `inputs` and `output_tensor_names` live on this stack
  // frame and are handed to the task by pointer, so this call must block
  // until the task has been completed (signalled via `done`).
  Notification done;
  Status status;
  auto task = absl::make_unique<BatchingSessionTask>();
  task->enqueue_time_micros = EnvTime::NowMicros();
  task->run_options = run_options;
  TF_RETURN_IF_ERROR(ComputeInputSize(inputs, &task->zeroth_dim_size));
  task->inputs = &inputs;
  task->output_tensor_names = &output_tensor_names;
  task->done = &done;
  task->status = &status;
  task->outputs = outputs;
  task->run_metadata = run_metadata;
  task->thread_pool_options = thread_pool_options;
  task->thread_safe_status = std::make_shared<ThreadSafeStatus>();
  task->shared_outputs = std::make_shared<std::vector<std::vector<Tensor>>>();
  task->split_run_metadatas = absl::make_unique<std::vector<RunMetadata>>();

  TF_RETURN_IF_ERROR(batch_scheduler->Schedule(&task));
  done.WaitForNotification();
  return status;
}