in tensorflow_serving/servables/tensorflow/tflite_session.cc [300:385]
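// Splits `*input_task_ptr` into partial tasks, each no larger than
// `max_batch_size`, optionally using the first partial task to fill the
// `open_batch_remaining_slot` slots of an already-open batch. The partial
// tasks share an IncrementalBarrier whose done callback concatenates their
// outputs back into the original task's outputs and signals its completion.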
Status TfLiteSession::SplitTfLiteInputTask(
    std::unique_ptr<TfLiteBatchTask>* input_task_ptr,
    int open_batch_remaining_slot, int max_batch_size,
    std::vector<std::unique_ptr<TfLiteBatchTask>>* output_tasks) {
  auto* input_task = input_task_ptr->get();
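  // State shared by all partial tasks: one output-tensor vector per split,
  // plus a thread-safe status that records the first partial failure.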
  auto split_output =
      std::make_shared<std::vector<std::unique_ptr<std::vector<Tensor>>>>();
  auto partial_status = std::make_shared<ThreadSafeStatus>();
  auto split_task_done_callback = [split_output, partial_status, input_task]() {
    // Notify the original input task once this callback finishes.
    auto cleanup = gtl::MakeCleanup([done_notification = input_task->done]() {
      done_notification->Notify();
    });
    // The partial status is set by the split tasks while they run; propagate
    // the first failure to the original task.
    if (!partial_status->status().ok()) {
      *input_task->status = partial_status->status();
      return;
    }
    // Number of split tasks whose outputs must be concatenated.
    int output_size = split_output->size();
    // Each split produces the same number of output tensors.
    int tensor_size = (*split_output)[0]->size();
    // For each output tensor, concatenate the slices from every split task.
    for (int tensor_idx = 0; tensor_idx < tensor_size; ++tensor_idx) {
      Tensor output_tensor;  // The concatenated tensor.
      std::vector<Tensor> to_concatenate;
      to_concatenate.reserve(output_size);
      // Collect this tensor's slice from each split task, in task order.
      for (int output_idx = 0; output_idx < output_size; ++output_idx) {
        to_concatenate.push_back(
            std::move((*(*split_output)[output_idx])[tensor_idx]));
      }
      const auto concat_status = tensor::Concat(to_concatenate, &output_tensor);
      if (!concat_status.ok()) {
        *input_task->status = concat_status;
        return;
      }
      // Append the concatenated tensor to the original task's outputs.
      input_task->outputs->push_back(output_tensor);
    }
    *input_task->status = Status::OK();
  };
  // The callback runs only after all partial tasks have finished.
  IncrementalBarrier barrier(std::move(split_task_done_callback));
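  // Compute the size of each partial task: first fill the remaining slots of
  // the open batch, then split the rest into chunks of at most max_batch_size.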
  std::vector<int64_t> output_task_sizes;
  if (open_batch_remaining_slot > 0) {
    output_task_sizes.push_back(open_batch_remaining_slot);
    split_output->emplace_back(absl::make_unique<std::vector<Tensor>>());
  }
  for (int left_task_size = input_task->size() - open_batch_remaining_slot;
       left_task_size > 0; left_task_size -= max_batch_size) {
    int next_task_size = std::min(left_task_size, max_batch_size);
    output_task_sizes.push_back(next_task_size);
    split_output->emplace_back(absl::make_unique<std::vector<Tensor>>());
  }
  const int output_task_num = output_task_sizes.size();
  output_tasks->reserve(output_task_num);
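  // Create one partial task per computed size. Each task writes its outputs
  // into its own slot of split_output and shares the barrier and status, so
  // the done callback fires exactly once, after every partial task completes.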
  for (int i = 0; i < output_task_num; ++i) {
    std::unique_ptr<TfLiteBatchTask> task;
    TfLiteBatchTask::CreatePartialTfLiteBatchTask(
        input_task->input_indices, input_task->output_tensor_names,
        (*split_output)[i].get(), barrier.Inc(), partial_status.get(), &task);
    output_tasks->push_back(std::move(task));
  }
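  // Split each input tensor along the batch (0th) dimension and hand the
  // resulting slices to the partial tasks, preserving order.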
  for (int i = 0; i < input_task->inputs.size(); ++i) {
    const Tensor& input = input_task->inputs[i];
    std::vector<Tensor> split_tensors;
    auto status = tensor::Split(input, output_task_sizes, &split_tensors);
    if (!status.ok()) {
      return status;
    }
    for (int output_idx = 0; output_idx < output_task_num; ++output_idx) {
      auto& output_task = (*output_tasks)[output_idx];
      output_task->inputs.push_back(std::move(split_tensors[output_idx]));
    }
  }
  return Status::OK();
}