in tensorflow_serving/batching/batching_session.cc [807:957]
// Splits one large `BatchingSessionTask` into several smaller tasks along the
// zeroth (batch) dimension so that each piece fits in a batch of at most
// `max_batch_size` elements. The split sizes are computed by
// `internal::InputSplitMetadata` from (task size, open_batch_remaining_slot,
// max_batch_size). All splits share rendezvous state (outputs, status, run
// metadata) already stored on `input_task`; once every split finishes, a
// barrier-triggered callback concatenates the per-split outputs and
// aggregates per-split cost-graph costs into the caller's RunMetadata.
//
// Args:
//   input_task_ptr: the task to split; its shared_* members become the
//     rendezvous state for the splits.
//   open_batch_remaining_slot: free capacity of the currently open batch
//     (used by InputSplitMetadata when sizing the first split).
//   max_batch_size: upper bound on the size of each split.
//   output_tasks: receives the newly created split tasks.
//
// Returns errors::Internal if the underlying tensor split fails or produces
// an unexpected number of splits; Status::OK() otherwise.
Status SplitInputTask(
    std::unique_ptr<BatchingSessionTask>* input_task_ptr,
    int open_batch_remaining_slot, int max_batch_size,
    std::vector<std::unique_ptr<BatchingSessionTask>>* output_tasks) {
  BatchingSessionTask& input_task = *(*input_task_ptr);
  const int64_t input_task_size = input_task.size();
  DCHECK_GT(input_task_size, 0);

  // `split_task_done_callback` runs only after all split tasks are complete.
  // Everything it needs is captured by value (shared pointers / plain values)
  // from `input_task`, so the callback does not depend on `input_task`'s
  // lifetime when it eventually fires.
  std::function<void()> split_task_done_callback =
      [done_notification = input_task.done,
       shared_outputs = input_task.shared_outputs,
       shared_status = input_task.thread_safe_status,
       num_output = input_task.output_tensor_names->size(),
       outputs = input_task.outputs, status = input_task.status,
       run_metadata = input_task.run_metadata,
       split_run_metadatas = input_task.split_run_metadatas]() {
        // Regardless of how we exit below, publish the final status and
        // signal completion to the waiting caller.
        auto finally = gtl::MakeCleanup([&] {
          *status = shared_status->status();
          done_notification->Notify();
        });

        // Some slices of tasks encounter errors, return early without
        // processing per-split result.
        if (!shared_status->status().ok()) {
          return;
        }

        // Stitch the splits back together: the i-th output tensor of the
        // original task is the concatenation (along dim 0) of the i-th output
        // tensor from each split, in split order.
        for (int i = 0; i < num_output; ++i) {
          Tensor output_tensor;
          // Concat i-th tensor from each split into i-th tensor of output.
          std::vector<Tensor> to_concatenate;
          to_concatenate.reserve(shared_outputs->size());
          for (int j = 0; j < shared_outputs->size(); ++j) {
            to_concatenate.push_back(std::move((*shared_outputs)[j][i]));
          }
          const auto concat_status =
              tensor::Concat(to_concatenate, &output_tensor);
          if (!concat_status.ok()) {
            // Surface the concat failure via the shared status; the cleanup
            // above copies it into `*status` before notifying.
            shared_status->Update(concat_status);
            return;
          }
          outputs->push_back(std::move(output_tensor));
        }

        // `cost_dimension_map` aggregates costs from all splits for each
        // dimension.
        absl::flat_hash_map<string, float> cost_dimension_map;
        for (const auto& split : *split_run_metadatas) {
          if (split.has_cost_graph()) {
            for (const auto& cost : split.cost_graph().cost()) {
              cost_dimension_map[cost.dimension()] += cost.cost();
            }
          }
        }

        // Use the first split's RunMetadata as the template for the merged
        // result, then replace its per-dimension costs with the sums computed
        // above. Only dimensions already present in the template are kept.
        *run_metadata = (*split_run_metadatas)[0];
        std::vector<string> cost_dimensions;
        for (const auto& cost_and_dimension :
             run_metadata->cost_graph().cost()) {
          cost_dimensions.push_back(cost_and_dimension.dimension());
        }
        run_metadata->mutable_cost_graph()->clear_cost();
        for (const auto& dimension : cost_dimensions) {
          const auto iter = cost_dimension_map.find(dimension);
          if (iter != cost_dimension_map.end()) {
            auto graph_cost = run_metadata->mutable_cost_graph()->add_cost();
            graph_cost->set_dimension(iter->first);
            graph_cost->set_cost(iter->second);
          }
        }
      };
  // Each split task below holds one `barrier.Inc()` callback; the barrier
  // invokes `split_task_done_callback` after all of them have run.
  IncrementalBarrier barrier(split_task_done_callback);

  // Compute the per-split sizes for this task.
  const internal::InputSplitMetadata input_split_metadata(
      input_task_size, open_batch_remaining_slot, max_batch_size);

  // Creates an array of int64_t from an array of int, since `tensor::Split`
  // requires an array of int64.
  const absl::FixedArray<int64_t> output_task_sizes(
      input_split_metadata.task_sizes().begin(),
      input_split_metadata.task_sizes().end());
  const int num_batches = output_task_sizes.size();

  // Pre-size the shared per-split slots so each split can later write its
  // results by `split_index` without resizing concurrently.
  input_task.shared_outputs->resize(num_batches);
  for (int i = 0; i < num_batches; ++i) {
    (*input_task.shared_outputs)[i].reserve(
        input_task.output_tensor_names->size());
  }
  input_task.split_run_metadatas->resize(num_batches);

  // Create one output task per split. All of them point at the same shared
  // rendezvous state and are marked partial.
  output_tasks->reserve(num_batches);
  for (int i = 0; i < num_batches; i++) {
    auto task = absl::make_unique<BatchingSessionTask>();
    task->enqueue_time_micros = input_task.enqueue_time_micros;
    task->run_options = input_task.run_options;
    task->zeroth_dim_size = output_task_sizes[i];
    // `task->owned_input` will be initialized separately out of this for-loop.
    task->output_tensor_names = input_task.output_tensor_names;
    task->owned_split_inputs =
        absl::make_unique<std::vector<std::pair<string, Tensor>>>();
    task->split_index = i;
    task->shared_outputs = input_task.shared_outputs;
    task->thread_safe_status = input_task.thread_safe_status;
    task->is_partial = true;
    task->done_callback = barrier.Inc();
    task->thread_pool_options = input_task.thread_pool_options;
    task->split_run_metadatas = input_task.split_run_metadatas;
    output_tasks->push_back(std::move(task));
  }

  const int num_input_tensors = input_task.inputs->size();

  // Splits each input tensor according to `output_task_sizes`, and
  // initializes input of `output_tasks` with split results.
  for (int i = 0; i < num_input_tensors; ++i) {
    std::vector<Tensor> split_tensors;
    const string& tensor_name = (*input_task.inputs)[i].first;
    const Tensor& input_tensor = (*input_task.inputs)[i].second;
    // TODO(b/158393551):
    // Figure out the optimal implementation of Split, by using
    // 'Tensor::Slice' and eliminating unnecessary memcpy as much as possible.
    const Status split_status =
        tensor::Split(input_tensor, output_task_sizes, &split_tensors);
    if (!split_status.ok()) {
      return errors::Internal(
          "When splitting input, Tensor split operation failed: ",
          split_status.ToString());
    }
    if (split_tensors.size() != output_task_sizes.size()) {
      return errors::Internal(
          "When splitting input, tensor split operation did not work as "
          "expected; got ",
          split_tensors.size(), " splits; expected ", output_task_sizes.size());
    }
    // Hand the j-th slice of this tensor to the j-th output task.
    for (int j = 0; j < output_tasks->size(); ++j) {
      BatchingSessionTask& output_task = *((*output_tasks)[j]);
      output_task.owned_split_inputs->push_back(
          std::make_pair(tensor_name, split_tensors[j]));
    }
  }
  return Status::OK();
}