Status TfLiteSession::SplitTfLiteInputTask()

in tensorflow_serving/servables/tensorflow/tflite_session.cc [300:385]


// Splits one oversized batch task into several partial tasks, each no larger
// than `max_batch_size`. The partial tasks share a completion barrier; once
// all of them have run, their per-split output tensors are concatenated back
// into `input_task`'s outputs and its done notification is fired.
//
// Args:
//   input_task_ptr: the task to split; retains ownership. Its outputs/status/
//       done notification are populated by the merge callback below.
//   open_batch_remaining_slot: free capacity in the currently-open batch; if
//       positive, the first split is sized to exactly fill that batch.
//   max_batch_size: upper bound on the size of every subsequent split.
//   output_tasks: receives the newly created partial tasks.
//
// Returns: OK on success, or the error from tensor::Split if any input
// tensor cannot be split along its 0th dimension.
Status TfLiteSession::SplitTfLiteInputTask(
    std::unique_ptr<TfLiteBatchTask>* input_task_ptr,
    int open_batch_remaining_slot, int max_batch_size,
    std::vector<std::unique_ptr<TfLiteBatchTask>>* output_tasks) {
  auto* input_task = input_task_ptr->get();
  // One output-tensor vector per partial task; filled in by the partial
  // tasks as they execute, consumed by the merge callback.
  auto split_output =
      std::make_shared<std::vector<std::unique_ptr<std::vector<Tensor>>>>();
  auto partial_status = std::make_shared<ThreadSafeStatus>();
  auto split_task_done_callback = [split_output, partial_status, input_task]() {
    // Always notify the input task, even on early error returns below.
    auto cleanup = gtl::MakeCleanup([done_notification = input_task->done]() {
      done_notification->Notify();
    });

    // partial_status is set by the partial tasks while they run; if any of
    // them failed, propagate that failure and skip concatenation.
    if (!partial_status->status().ok()) {
      *input_task->status = partial_status->status();
      return;
    }

    // Number of partial tasks whose outputs must be concatenated.
    int output_size = split_output->size();
    // Each split produces the same number of output tensors.
    int tensor_size = (*split_output)[0]->size();

    // Concatenate the tensor_idx-th output of every split, in split order.
    for (int tensor_idx = 0; tensor_idx < tensor_size; ++tensor_idx) {
      Tensor output_tensor;  // the concatenated tensor
      std::vector<Tensor> to_concatenate;
      to_concatenate.reserve(output_size);
      for (int output_idx = 0; output_idx < output_size; ++output_idx) {
        to_concatenate.push_back(
            std::move((*(*split_output)[output_idx])[tensor_idx]));
      }
      const auto concat_status = tensor::Concat(to_concatenate, &output_tensor);
      if (!concat_status.ok()) {
        *input_task->status = concat_status;
        return;
      }
      // Move (not copy) the concatenated tensor into the input task's
      // outputs; output_tensor is not used again.
      input_task->outputs->push_back(std::move(output_tensor));
    }
    *input_task->status = Status::OK();
  };

  // The callback runs only after every partial task has finished.
  IncrementalBarrier barrier(std::move(split_task_done_callback));
  std::vector<int64_t> output_task_sizes;

  // First split fills the remaining slots of the currently-open batch.
  if (open_batch_remaining_slot > 0) {
    output_task_sizes.push_back(open_batch_remaining_slot);
    split_output->emplace_back(absl::make_unique<std::vector<Tensor>>());
  }

  // Remaining work is carved into chunks of at most max_batch_size.
  for (int left_task_size = input_task->size() - open_batch_remaining_slot;
       left_task_size > 0; left_task_size -= max_batch_size) {
    int next_task_size = std::min(left_task_size, max_batch_size);
    output_task_sizes.push_back(next_task_size);
    split_output->emplace_back(absl::make_unique<std::vector<Tensor>>());
  }

  // Create one partial task per chunk, each holding a barrier token so the
  // merge callback fires only after all of them complete.
  const int output_task_num = output_task_sizes.size();
  output_tasks->reserve(output_task_num);
  for (int i = 0; i < output_task_num; ++i) {
    std::unique_ptr<TfLiteBatchTask> task;
    TfLiteBatchTask::CreatePartialTfLiteBatchTask(
        input_task->input_indices, input_task->output_tensor_names,
        (*split_output)[i].get(), barrier.Inc(), partial_status.get(), &task);
    output_tasks->push_back(std::move(task));
  }

  // Split every input tensor along dim 0 and distribute the pieces to the
  // partial tasks in order.
  for (const Tensor& input : input_task->inputs) {
    std::vector<Tensor> split_tensors;
    auto status = tensor::Split(input, output_task_sizes, &split_tensors);
    if (!status.ok()) {
      return status;
    }
    for (int output_idx = 0; output_idx < output_task_num; ++output_idx) {
      auto& output_task = (*output_tasks)[output_idx];
      output_task->inputs.push_back(std::move(split_tensors[output_idx]));
    }
  }
  return Status::OK();
}