Status SplitInputTask()

in tensorflow_serving/batching/batching_session.cc [807:957]


Status SplitInputTask(
    std::unique_ptr<BatchingSessionTask>* input_task_ptr,
    int open_batch_remaining_slot, int max_batch_size,
    std::vector<std::unique_ptr<BatchingSessionTask>>* output_tasks) {
  BatchingSessionTask& input_task = *(*input_task_ptr);
  const int64_t input_task_size = input_task.size();

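  // A task must contain at least one example to be split.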
  DCHECK_GT(input_task_size, 0);

  // `split_task_done_callback` runs only after all split tasks are complete.
  std::function<void()> split_task_done_callback =
      [done_notification = input_task.done,
       shared_outputs = input_task.shared_outputs,
       shared_status = input_task.thread_safe_status,
       num_output = input_task.output_tensor_names->size(),
       outputs = input_task.outputs, status = input_task.status,
       run_metadata = input_task.run_metadata,
       split_run_metadatas = input_task.split_run_metadatas]() {
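        // Whatever happens below, propagate the aggregated status to the
        // caller's `status` and signal `done_notification` on the way out.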
        auto finally = gtl::MakeCleanup([&] {
          *status = shared_status->status();
          done_notification->Notify();
        });

        // If any split task encountered an error, return early without
        // processing the per-split results.
        if (!shared_status->status().ok()) {
          return;
        }

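        // All splits succeeded: stitch the per-split outputs back together,
        // one concatenated tensor per requested output name.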
        for (int i = 0; i < num_output; ++i) {
          Tensor output_tensor;

          // Concatenate the i-th tensor from each split into the i-th
          // output tensor.
          std::vector<Tensor> to_concatenate;
          to_concatenate.reserve(shared_outputs->size());
          for (int j = 0; j < shared_outputs->size(); ++j) {
            to_concatenate.push_back(std::move((*shared_outputs)[j][i]));
          }
          const auto concat_status =
              tensor::Concat(to_concatenate, &output_tensor);
          if (!concat_status.ok()) {
            shared_status->Update(concat_status);
            return;
          }

          outputs->push_back(std::move(output_tensor));
        }

        // `cost_dimension_map` aggregates costs from all splits for each
        // dimension.
        absl::flat_hash_map<string, float> cost_dimension_map;
        for (const auto& split : *split_run_metadatas) {
          if (split.has_cost_graph()) {
            for (const auto& cost : split.cost_graph().cost()) {
              cost_dimension_map[cost.dimension()] += cost.cost();
            }
          }
        }

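        // Use the first split's RunMetadata as the template, then rewrite
        // its cost graph with the per-dimension costs aggregated above,
        // preserving the original ordering of the dimensions.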
        *run_metadata = (*split_run_metadatas)[0];
        std::vector<string> cost_dimensions;
        for (const auto& cost_and_dimension :
             run_metadata->cost_graph().cost()) {
          cost_dimensions.push_back(cost_and_dimension.dimension());
        }
        run_metadata->mutable_cost_graph()->clear_cost();
        for (const auto& dimension : cost_dimensions) {
          const auto iter = cost_dimension_map.find(dimension);
          if (iter != cost_dimension_map.end()) {
            auto graph_cost = run_metadata->mutable_cost_graph()->add_cost();
            graph_cost->set_dimension(iter->first);
            graph_cost->set_cost(iter->second);
          }
        }
      };
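  // The barrier invokes `split_task_done_callback` exactly once, after every
  // split task created below has run the callback returned by `barrier.Inc()`.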
  IncrementalBarrier barrier(split_task_done_callback);

  const internal::InputSplitMetadata input_split_metadata(
      input_task_size, open_batch_remaining_slot, max_batch_size);

  // Creates an array of int64_t from an array of int, since `tensor::Split`
  // requires an array of int64_t.
  const absl::FixedArray<int64_t> output_task_sizes(
      input_split_metadata.task_sizes().begin(),
      input_split_metadata.task_sizes().end());
  const int num_batches = output_task_sizes.size();

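  // Pre-size the shared per-split output and RunMetadata containers so each
  // split task can write into its own slot, indexed by `split_index`.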
  input_task.shared_outputs->resize(num_batches);

  for (int i = 0; i < num_batches; ++i) {
    (*input_task.shared_outputs)[i].reserve(
        input_task.output_tensor_names->size());
  }

  input_task.split_run_metadatas->resize(num_batches);

  output_tasks->reserve(num_batches);
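  // Create one partial task per batch. Each task shares the output, status,
  // and metadata containers of the original task and reports completion
  // through the barrier.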
  for (int i = 0; i < num_batches; i++) {
    auto task = absl::make_unique<BatchingSessionTask>();
    task->enqueue_time_micros = input_task.enqueue_time_micros;
    task->run_options = input_task.run_options;
    task->zeroth_dim_size = output_task_sizes[i];
    // `task->owned_split_inputs` will be populated separately, outside of
    // this for-loop.
    task->output_tensor_names = input_task.output_tensor_names;

    task->owned_split_inputs =
        absl::make_unique<std::vector<std::pair<string, Tensor>>>();
    task->split_index = i;
    task->shared_outputs = input_task.shared_outputs;
    task->thread_safe_status = input_task.thread_safe_status;
    task->is_partial = true;
    task->done_callback = barrier.Inc();
    task->thread_pool_options = input_task.thread_pool_options;

    task->split_run_metadatas = input_task.split_run_metadatas;

    output_tasks->push_back(std::move(task));
  }

  const int num_input_tensors = input_task.inputs->size();

  // Splits each input tensor according to `output_task_sizes`, and
  // initializes input of `output_tasks` with split results.
  for (int i = 0; i < num_input_tensors; ++i) {
    std::vector<Tensor> split_tensors;
    const string& tensor_name = (*input_task.inputs)[i].first;
    const Tensor& input_tensor = (*input_task.inputs)[i].second;
    // TODO(b/158393551):
    // Figure out the optimal implementation of Split by using
    // `Tensor::Slice` and eliminating unnecessary memcpys as much as possible.
    const Status split_status =
        tensor::Split(input_tensor, output_task_sizes, &split_tensors);
    if (!split_status.ok()) {
      return errors::Internal(
          "When splitting input, Tensor split operation failed: ",
          split_status.ToString());
    }
    if (split_tensors.size() != output_task_sizes.size()) {
      return errors::Internal(
          "When splitting input, tensor split operation did not work as "
          "expected; got ",
          split_tensors.size(), " splits; expected ", output_task_sizes.size());
    }
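    // Hand the j-th slice of this tensor to the j-th split task.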
    for (int j = 0; j < output_tasks->size(); ++j) {
      BatchingSessionTask& output_task = *((*output_tasks)[j]);
      output_task.owned_split_inputs->push_back(
          std::make_pair(tensor_name, split_tensors[j]));
    }
  }
  return Status::OK();
}
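
For context, here is a minimal standalone sketch (not part of batching_session.cc) of the tensor::Split / tensor::Concat roundtrip that SplitInputTask and its done-callback rely on. The SplitConcatRoundtrip function name and the example shapes and split sizes are illustrative only; the snippet assumes the Split/Concat helpers declared in tensorflow/core/framework/tensor_util.h.

#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/lib/core/status.h"

// Splits a batch of 5 examples along dimension 0 into sub-batches of 2 and
// 3, then concatenates the pieces back together: the same pair of
// operations `SplitInputTask` and `split_task_done_callback` use.
tensorflow::Status SplitConcatRoundtrip() {
  using tensorflow::Tensor;

  // A batch of 5 examples, each a vector of 3 floats, filled with zeros.
  Tensor batch(tensorflow::DT_FLOAT, tensorflow::TensorShape({5, 3}));
  batch.flat<float>().setZero();

  // Split along the 0th (batch) dimension, analogous to splitting an input
  // tensor according to `output_task_sizes`.
  std::vector<Tensor> splits;
  const tensorflow::Status split_status =
      tensorflow::tensor::Split(batch, {2, 3}, &splits);
  if (!split_status.ok()) return split_status;

  // Concatenate the pieces back, analogous to stitching per-split outputs
  // into the caller's output tensors.
  Tensor rejoined;  // Will have shape {5, 3} again.
  const tensorflow::Status concat_status =
      tensorflow::tensor::Concat(splits, &rejoined);
  if (!concat_status.ok()) return concat_status;

  return tensorflow::Status::OK();
}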