in tensorflow_serving/batching/batching_session.cc [472:576]
// Merges the input tensors of all tasks in 'batch' into one concatenated
// tensor per input name, padding the batch (0th) dimension up to the lowest
// allowed batch size. Appends {name, tensor} pairs to '*merged_inputs' in
// signature order. Returns an error if the tasks' inputs do not conform to
// 'signature', if same-named tensors disagree in non-batch dims while
// variable-length padding is disabled, or if concatenation fails.
Status BatchingSession::MergeInputTensors(
    const TensorSignature& signature, const Batch<BatchingSessionTask>& batch,
    std::vector<std::pair<string, Tensor>>* merged_inputs) {
  DCHECK_GE(batch.num_tasks(), 1);
  if (batch.num_tasks() < 1) {
    return errors::Internal("Batch size expected to be positive; was ",
                            batch.num_tasks());
  }
  // Round the actual batch size up to the nearest allowed size; the
  // difference is filled with padding rows below.
  const int lowest_allowed_batch_size =
      RoundToLowestAllowedBatchSize(options_.allowed_batch_sizes, batch.size());
  const int padding_size = lowest_allowed_batch_size - batch.size();
  profiler::TraceMe trace_me([lowest_allowed_batch_size, padding_size]() {
    return profiler::TraceMeEncode(
        "MergeInputTensors",
        {{"batch_size_after_padding", lowest_allowed_batch_size},
         {"padding_amount", padding_size}});
  });
  RecordPaddingSize<BatchingSessionTask>(padding_size,
                                         lowest_allowed_batch_size);
  RecordProcessedBatchSize<BatchingSessionTask>(lowest_allowed_batch_size);

  // For each input tensor name, a vector of tensors from the individual tasks.
  std::map<string, std::vector<Tensor>> tensors_to_merge;
  // For each input tensor name, a vector of maximum dimension sizes
  // among tensors from individual tasks. Only computed when variable-length
  // padding is enabled.
  absl::optional<std::map<string, std::vector<int>>> max_dim_sizes;
  if (options_.pad_variable_length_inputs) {
    std::vector<std::vector<std::pair<string, Tensor>>> all_task_inputs =
        GetTaskInputsVector(batch);
    max_dim_sizes = CalculateMaxDimSizes(all_task_inputs);
  }

  // Populate 'tensors_to_merge'.
  for (int i = 0; i < batch.num_tasks(); ++i) {
    const std::vector<std::pair<string, Tensor>>& task_inputs =
        GetTaskInput(batch.task(i));
    for (const auto& entry : task_inputs) {
      const string& tensor_name = entry.first;
      const Tensor& tensor = entry.second;

      std::vector<Tensor>& tensor_vec = tensors_to_merge[tensor_name];
      Tensor optionally_padded_tensor;
      if (options_.pad_variable_length_inputs) {
        TF_RETURN_IF_ERROR(AddPadding(tensor, (*max_dim_sizes)[tensor_name],
                                      &optionally_padded_tensor));
      } else {
        optionally_padded_tensor = tensor;
        // Check whether tensors with the same name have equal dims
        // (except zeroth dim) when padding is turned off.
        if (i > 0) {  // added at least one task to tensors_to_merge
          // Bind by const reference via the existing 'tensor_vec' reference
          // instead of re-looking-up the map and copying the shape.
          const TensorShape& reference_shape = tensor_vec[0].shape();
          if (!AreShapesEqualExceptZeroDim(tensor.shape(), reference_shape)) {
            return errors::FailedPrecondition(
                "Tensors with name '" + tensor_name +
                "' from different tasks have different shapes and padding is "
                "turned off. Set pad_variable_length_inputs to true, or ensure "
                "that all tensors with the same name have equal dimensions "
                "starting with the first dim.");
          }
        }
      }
      tensor_vec.push_back(std::move(optionally_padded_tensor));

      if (i == batch.num_tasks() - 1 && padding_size > 0) {
        // This is the last task. Insert padding.
        //
        // Use the first row of this task's tensor as the padding data. (We
        // know it represents a valid input tensor row, so it should always be
        // safe to use for padding.)
        //
        // Slice() operates on the 0th dimension, which is the batch dimension.
        // It avoids a deep copy, which is a nice efficiency bonus.
        const Tensor padding_tensor = tensor_vec.back().Slice(0, 1);
        // Use a distinct index name here: the original code shadowed the
        // outer task index 'i', which is error-prone (-Wshadow).
        for (int pad_idx = 0; pad_idx < padding_size; ++pad_idx) {
          tensor_vec.push_back(padding_tensor);
        }
      }
    }
  }

  // Merge the per-task tensors into one batch tensor per input name, emitted
  // in the order dictated by the signature.
  DCHECK_EQ(signature.input_tensors.size(), tensors_to_merge.size());
  if (tensors_to_merge.size() != signature.input_tensors.size()) {
    return errors::Internal(
        "One or more tasks does not conform to batch signature");
  }
  // One entry per signature input; reserve up front to avoid reallocations.
  merged_inputs->reserve(merged_inputs->size() +
                         signature.input_tensors.size());
  for (const string& tensor_name : signature.input_tensors) {
    auto tensors = tensors_to_merge.find(tensor_name);
    DCHECK(tensors != tensors_to_merge.end());
    if (tensors == tensors_to_merge.end()) {
      return errors::Internal(
          "One or more tasks does not conform to batch signature");
    }
    Tensor concated;
    const Status concat_status = tensor::Concat(tensors->second, &concated);
    DCHECK(concat_status.ok()) << concat_status.ToString();
    if (!concat_status.ok()) {
      return errors::Internal("Tensor concat operation failed: ",
                              concat_status.ToString());
    }
    merged_inputs->push_back({tensor_name, std::move(concated)});
  }
  return Status::OK();
}