in tensorflow_serving/batching/batching_session.cc [578:663]
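// Breaks the batched output tensors produced by the underlying merged
// Session::Run call back apart into per-task outputs, dropping any rows
// that were added as padding.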
Status BatchingSession::SplitOutputTensors(
    const TensorSignature& signature,
    const std::vector<Tensor>& combined_outputs,
    Batch<BatchingSessionTask>* batch) {
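  // The batch is expected to be non-empty; an empty batch here would likely
  // point to a bug upstream in the batch scheduler.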
  DCHECK_GE(batch->num_tasks(), 1);
  if (batch->num_tasks() < 1) {
    return errors::Internal("Batch size expected to be positive; was ",
                            batch->num_tasks());
  }
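
  // Record how many rows (zeroth-dimension elements) each task contributed;
  // these sizes determine how each batched output tensor is split back up.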
  std::vector<int64_t> task_sizes_plus_optional_padding;
  task_sizes_plus_optional_padding.reserve(batch->num_tasks());
  for (int i = 0; i < batch->num_tasks(); ++i) {
    task_sizes_plus_optional_padding.push_back(batch->task(i).zeroth_dim_size);
  }
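
  // If the batch was padded up to the next allowed batch size, treat the
  // padding as one extra trailing "task" so the split below produces a final
  // slice that can simply be dropped. For example, with allowed_batch_sizes
  // of {2, 4, 8} and a batch of size 5, the batch is padded to 8 and
  // padding_size is 3.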
  const int padding_size = RoundToLowestAllowedBatchSize(
                               options_.allowed_batch_sizes, batch->size()) -
                           batch->size();
  if (padding_size > 0) {
    task_sizes_plus_optional_padding.push_back(padding_size);
  }

  // For each output tensor name, a divided-up tensor with one entry per task.
  std::map<string, std::vector<Tensor>> split_tensors;

  // Populate 'split_tensors'.
  DCHECK_EQ(signature.output_tensors.size(), combined_outputs.size());
  if (combined_outputs.size() != signature.output_tensors.size()) {
    return errors::Internal("Wrong number of batched output tensors");
  }
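  // 'signature.output_tensors' is an ordered set; the indexing below relies
  // on its iteration order matching the order in which the outputs were
  // packed into 'combined_outputs'.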
  const std::vector<string> output_tensors(signature.output_tensors.begin(),
                                           signature.output_tensors.end());
  for (int i = 0; i < output_tensors.size(); ++i) {
    const string& tensor_name = output_tensors[i];
    const Tensor& tensor = combined_outputs[i];
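    // A scalar output has no batch dimension to split along, and every
    // batched output must have exactly batch->size() + padding_size rows.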
    if (tensor.shape().dims() == 0) {
      return errors::FailedPrecondition(
          "Batched output tensor has 0 dimensions");
    }
    if (tensor.shape().dim_size(0) != batch->size() + padding_size) {
      return errors::FailedPrecondition(
          "Batched output tensor's 0th dimension does not equal the sum of "
          "the 0th dimension sizes of the input tensors");
    }
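
    // Slice the batched tensor along dimension 0, yielding one piece per
    // task plus, if the batch was padded, a final piece holding the padding
    // rows. For example, with task sizes {2, 3} and padding 1, a tensor of
    // shape [6, ...] splits into pieces of shape [2, ...], [3, ...], and
    // [1, ...].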
    std::vector<Tensor> split_tensor;
    const Status split_status =
        tensor::Split(tensor, task_sizes_plus_optional_padding, &split_tensor);
    DCHECK(split_status.ok()) << split_status.ToString();
    if (!split_status.ok()) {
      return errors::Internal("Tensor split operation failed: ",
                              split_status.ToString());
    }
    DCHECK_EQ(split_tensor.size(), task_sizes_plus_optional_padding.size());
    if (split_tensor.size() != task_sizes_plus_optional_padding.size()) {
      return errors::Internal(
          "Tensor split operation did not work as expected; got ",
          split_tensor.size(), " splits; expected ",
          task_sizes_plus_optional_padding.size());
    }
    split_tensors[tensor_name] = std::move(split_tensor);
  }
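
  // Route each task's slice of every output tensor it requested back to that
  // task. Tasks created by splitting one large request append their slice
  // into the request's shared output vector at the task's split index; all
  // other tasks receive their slice directly in their own output vector.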
  for (int i = 0; i < batch->num_tasks(); ++i) {
    BatchingSessionTask* task = batch->mutable_task(i);
    for (const string& tensor_name : *task->output_tensor_names) {
      auto split_tensor = split_tensors.find(tensor_name);
      DCHECK(split_tensor != split_tensors.end());
      if (split_tensor == split_tensors.end()) {
        return errors::Internal("Task does not conform to batch signature");
      }
      if (task->is_partial) {
        std::vector<Tensor>& tensor_vector =
            (*task->shared_outputs)[task->split_index];
        tensor_vector.push_back(std::move(split_tensor->second[i]));
      } else {
        task->outputs->push_back(std::move(split_tensor->second[i]));
      }
    }
  }
  // (The final element of each 'split_tensors' vector, if the batch was
  // padded, holds the padding rows; it is intentionally left behind here.)

  return Status::OK();
}