Status BatchingSession::SplitOutputTensors()

in tensorflow_serving/batching/batching_session.cc [578:663]


// Splits each batched output tensor produced for 'batch' back into per-task
// slices and appends those slices to the tasks' output vectors.
//
// 'combined_outputs[k]' is the batched result for the k-th name of
// 'signature.output_tensors', in that container's iteration order. Each such
// tensor must have at least one dimension, and its 0th dimension must equal
// batch->size() plus the padding needed to reach the allowed batch size
// computed by RoundToLowestAllowedBatchSize(options_.allowed_batch_sizes, ...).
// The trailing padding slice, if any, is discarded.
//
// For each task, one tensor is appended per entry of
// *task->output_tensor_names, in that order. The target is
// (*task->shared_outputs)[task->split_index] when task->is_partial, and
// *task->outputs otherwise.
//
// Returns:
//   - FailedPrecondition if a batched output is 0-dimensional or its 0th
//     dimension has an unexpected size.
//   - Internal on violated invariants: empty batch, output-count mismatch,
//     split failure, or a task naming an output absent from 'signature'.
Status BatchingSession::SplitOutputTensors(
    const TensorSignature& signature,
    const std::vector<Tensor>& combined_outputs,
    Batch<BatchingSessionTask>* batch) {
  DCHECK_GE(batch->num_tasks(), 1);
  if (batch->num_tasks() < 1) {
    return errors::Internal("Batch size expected to be positive; was ",
                            batch->num_tasks());
  }

  // Split sizes along dimension 0: one entry per task, followed by one entry
  // for the padding rows (if any) appended to reach an allowed batch size.
  std::vector<int64_t> task_sizes_plus_optional_padding;
  task_sizes_plus_optional_padding.reserve(batch->num_tasks() + 1);
  for (int i = 0; i < batch->num_tasks(); ++i) {
    task_sizes_plus_optional_padding.push_back(batch->task(i).zeroth_dim_size);
  }
  const int padding_size = RoundToLowestAllowedBatchSize(
                               options_.allowed_batch_sizes, batch->size()) -
                           batch->size();
  if (padding_size > 0) {
    task_sizes_plus_optional_padding.push_back(padding_size);
  }

  // For each output tensor name, a divided-up tensor with one entry per task
  // (plus a trailing padding entry when padding_size > 0).
  std::map<string, std::vector<Tensor>> split_tensors;

  // Populate 'split_tensors'.
  DCHECK_EQ(signature.output_tensors.size(), combined_outputs.size());
  if (combined_outputs.size() != signature.output_tensors.size()) {
    return errors::Internal("Wrong number of batched output tensors");
  }
  // Walk the signature's names in iteration order alongside
  // 'combined_outputs'. This avoids copying every name into a temporary
  // vector just to index it.
  size_t output_index = 0;
  for (const string& tensor_name : signature.output_tensors) {
    const Tensor& tensor = combined_outputs[output_index++];

    if (tensor.shape().dims() == 0) {
      return errors::FailedPrecondition(
          "Batched output tensor has 0 dimensions");
    }
    if (tensor.shape().dim_size(0) != batch->size() + padding_size) {
      return errors::FailedPrecondition(
          "Batched output tensor's 0th dimension does not equal the sum of the "
          "0th dimension sizes of the input tensors");
    }

    std::vector<Tensor> split_tensor;
    const Status split_status =
        tensor::Split(tensor, task_sizes_plus_optional_padding, &split_tensor);
    DCHECK(split_status.ok()) << split_status.ToString();
    if (!split_status.ok()) {
      return errors::Internal("Tensor split operation failed: ",
                              split_status.ToString());
    }
    DCHECK_EQ(split_tensor.size(), task_sizes_plus_optional_padding.size());
    if (split_tensor.size() != task_sizes_plus_optional_padding.size()) {
      return errors::Internal(
          "Tensor split operation did not work as expected; got ",
          split_tensor.size(), " splits; expected ",
          task_sizes_plus_optional_padding.size());
    }
    split_tensors[tensor_name] = std::move(split_tensor);
  }

  for (int i = 0; i < batch->num_tasks(); ++i) {
    BatchingSessionTask* task = batch->mutable_task(i);
    for (const string& tensor_name : *task->output_tensor_names) {
      auto split_tensor = split_tensors.find(tensor_name);
      DCHECK(split_tensor != split_tensors.end());
      if (split_tensor == split_tensors.end()) {
        return errors::Internal("Task does not conform to batch signature");
      }

      // Copy rather than move. 'split_tensors' holds one entry per distinct
      // name, but a task may list the same name more than once. Moving would
      // hand every repeat after the first a moved-from tensor. Copying a
      // tensorflow::Tensor shares its reference-counted buffer, so no tensor
      // data is duplicated.
      const Tensor& task_output = split_tensor->second[i];
      if (task->is_partial) {
        std::vector<Tensor>& tensor_vector =
            (*task->shared_outputs)[task->split_index];
        tensor_vector.push_back(task_output);
      } else {
        task->outputs->push_back(task_output);
      }
    }
  }
  // (Ignore a possible final split_tensors entry containing the padding.)

  return Status::OK();
}