Status BatchingSession::MergeInputTensors()

in tensorflow_serving/batching/batching_session.cc [472:576]
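
Merges the input tensors of all tasks in 'batch' into one concatenated tensor per input name, padding the batch up to the nearest allowed batch size and optionally padding variable-length inputs to a common shape, so that a single underlying Session::Run() call can serve the whole batch.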


Status BatchingSession::MergeInputTensors(
    const TensorSignature& signature, const Batch<BatchingSessionTask>& batch,
    std::vector<std::pair<string, Tensor>>* merged_inputs) {
  DCHECK_GE(batch.num_tasks(), 1);
  if (batch.num_tasks() < 1) {
    return errors::Internal("Batch size expected to be positive; was ",
                            batch.num_tasks());
  }

  const int lowest_allowed_batch_size =
      RoundToLowestAllowedBatchSize(options_.allowed_batch_sizes, batch.size());
  const int padding_size = lowest_allowed_batch_size - batch.size();
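  // Illustrative example (values assumed, not from the original): with
  // allowed_batch_sizes = {2, 4, 8} and batch.size() == 5,
  // lowest_allowed_batch_size is 8 and padding_size is 3. If
  // allowed_batch_sizes is empty, the batch size is used as-is and
  // padding_size is 0.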
  profiler::TraceMe trace_me([lowest_allowed_batch_size, padding_size]() {
    return profiler::TraceMeEncode(
        "MergeInputTensors",
        {{"batch_size_after_padding", lowest_allowed_batch_size},
         {"padding_amount", padding_size}});
  });
  RecordPaddingSize<BatchingSessionTask>(padding_size,
                                         lowest_allowed_batch_size);
  RecordProcessedBatchSize<BatchingSessionTask>(lowest_allowed_batch_size);

  // For each input tensor name, a vector of tensors from the individual tasks.
  std::map<string, std::vector<Tensor>> tensors_to_merge;
  // For each input tensor name, a vector of the maximum dimension sizes
  // among the tensors from the individual tasks.
  absl::optional<std::map<string, std::vector<int>>> max_dim_sizes;
  if (options_.pad_variable_length_inputs) {
    std::vector<std::vector<std::pair<string, Tensor>>> all_task_inputs =
        GetTaskInputsVector(batch);
    max_dim_sizes = CalculateMaxDimSizes(all_task_inputs);
  }
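  // For example (shapes illustrative): if task 0 supplies tensor "x" with
  // shape [1, 3] and task 1 supplies "x" with shape [2, 5], then
  // (*max_dim_sizes)["x"] is {2, 5} and AddPadding() below pads each "x" to
  // size 5 in dimension 1. Dimension 0 (the batch dimension) is not padded
  // by AddPadding(); it grows through concatenation and row padding instead.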
  // Populate 'tensors_to_merge'.
  for (int i = 0; i < batch.num_tasks(); ++i) {
    const std::vector<std::pair<string, Tensor>>& task_inputs =
        GetTaskInput(batch.task(i));
    for (const auto& entry : task_inputs) {
      const string& tensor_name = entry.first;
      const Tensor& tensor = entry.second;

      std::vector<Tensor>& tensor_vec = tensors_to_merge[tensor_name];
      Tensor optionally_padded_tensor;
      if (options_.pad_variable_length_inputs) {
        TF_RETURN_IF_ERROR(AddPadding(tensor, (*max_dim_sizes)[tensor_name],
                                      &optionally_padded_tensor));
      } else {
        optionally_padded_tensor = tensor;
        // Check whether tensors with the same name have equal dims
        // (except zeroth dim) when padding is turned off.
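        // For example, shapes [2, 5] and [3, 5] are compatible here (they
        // differ only in the batch dimension), while [2, 5] and [2, 6] are
        // not.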
        if (i > 0) {  // added at least one task to tensors_to_merge
          TensorShape reference_shape =
              tensors_to_merge[tensor_name][0].shape();
          if (!AreShapesEqualExceptZeroDim(tensor.shape(), reference_shape)) {
            return errors::FailedPrecondition(
                "Tensors with name '" + tensor_name +
                "' from different tasks have different shapes and padding is "
                "turned off. Set pad_variable_length_inputs to true, or ensure "
                "that all tensors with the same name have equal dimensions "
                "starting with the first dim.");
          }
        }
      }
      tensor_vec.push_back(std::move(optionally_padded_tensor));
      if (i == batch.num_tasks() - 1 && padding_size > 0) {
        // This is the last task. Insert padding.
        //
        // Use the first row of this task's tensor as the padding data. (We know
        // it represents a valid input tensor row, so it should always be safe
        // to use for padding.)
        //
        // Slice() operates on the 0th dimension, which is the batch dimension.
        // It avoids a deep copy, which is a nice efficiency bonus.
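        //
        // Pushing 'padding_tensor' repeatedly below copies only the Tensor
        // handle; every padding row shares the sliced row's underlying
        // buffer.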
        const Tensor padding_tensor = tensor_vec.back().Slice(0, 1);
        for (int j = 0; j < padding_size; ++j) {
          tensor_vec.push_back(padding_tensor);
        }
      }
    }
  }

  // Merge the tensors.
  DCHECK_EQ(signature.input_tensors.size(), tensors_to_merge.size());
  if (tensors_to_merge.size() != signature.input_tensors.size()) {
    return errors::Internal(
        "One or more tasks do not conform to the batch signature");
  }
  for (const string& tensor_name : signature.input_tensors) {
    auto tensors = tensors_to_merge.find(tensor_name);
    DCHECK(tensors != tensors_to_merge.end());
    if (tensors == tensors_to_merge.end()) {
      return errors::Internal(
          "One or more tasks do not conform to the batch signature");
    }
    Tensor concated;
    const Status concat_status = tensor::Concat(tensors->second, &concated);
    DCHECK(concat_status.ok()) << concat_status.ToString();
    if (!concat_status.ok()) {
      return errors::Internal("Tensor concat operation failed: ",
                              concat_status.ToString());
    }
    merged_inputs->push_back({tensor_name, std::move(concated)});
  }

  return Status::OK();
}
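
A minimal sketch (not part of the original file) of the tensor::Concat call used above; the function name and shapes are illustrative:

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/lib/core/status.h"

// Concatenates two per-task tensors for one input name along dimension 0
// (the batch dimension), as MergeInputTensors() does for each input. The
// non-batch dimensions must already match, which the padding logic above
// guarantees. On success, 'merged' has shape [5, 4].
tensorflow::Status ConcatExample() {
  tensorflow::Tensor a(tensorflow::DT_FLOAT, tensorflow::TensorShape({2, 4}));
  tensorflow::Tensor b(tensorflow::DT_FLOAT, tensorflow::TensorShape({3, 4}));
  tensorflow::Tensor merged;
  return tensorflow::tensor::Concat({a, b}, &merged);
}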