std::unique_ptr<NeuropodValueMap> TensorflowNeuropodBackend::infer_internal()

in source/neuropod/backends/tensorflow/tf_backend.cc [317:399]


std::unique_ptr<NeuropodValueMap> TensorflowNeuropodBackend::infer_internal(
    const NeuropodValueMap &inputs, const std::vector<std::string> &requested_outputs)
{
    // In TensorFlow, a callable is a way of running a subgraph given a set of inputs and
    // outputs. It's very similar to `session_->Run`, except that it gives more fine-grained
    // control over which devices tensors are placed on. See https://github.com/tensorflow/tensorflow/issues/5902
    // for more details.

    // Fetches and feeds for our callable
    // Note: these are ordered maps to make it easy to cache callables
    // Map from an output node_name to an output_name
    std::map<std::string, std::string> tensor_fetches;

    // Map from an input node_name to a Tensor
    std::map<std::string, tensorflow::Tensor> tensor_feeds;

    // Get the set of outputs we want to compute
    const auto &output_names = !requested_outputs.empty() ? requested_outputs : output_names_;

    // Transform neuropod output names to node names in the graph
    for (const auto &name : output_names)
    {
        const auto node_name = node_name_mapping_.find(name);
        if (node_name == node_name_mapping_.end())
        {
            NEUROPOD_ERROR("Node {} not found in node_name_mapping. "
                           "Ensure that all items in the input/output spec have a corresponding item "
                           "in the node_name_mapping.",
                           name);
        }

        // Add this node name as an output of the subgraph we want to run
        tensor_fetches.emplace(std::make_pair(node_name->second, name));
    }

    // Loop through all the input tensors and set up the feeds
    for (const auto &entry : inputs)
    {
        const auto node_name = node_name_mapping_.find(entry.first);
        if (node_name == node_name_mapping_.end())
        {
            NEUROPOD_ERROR("Node {} not found in node_name_mapping. "
                           "Ensure that all items in the input/output spec have a corresponding item "
                           "in the node_name_mapping.",
                           entry.first);
        }

        // Get the TensorFlow tensor from the Neuropod tensor
        const auto &input_data =
            std::dynamic_pointer_cast<NativeDataContainer<tensorflow::Tensor &>>(entry.second)->get_native_data();

        // Add this node name as an input to the subgraph we want to run
        tensor_feeds.emplace(std::make_pair(node_name->second, input_data));
    }

    // Get a handle to a (possibly cached) callable and create a vector to store our outputs
    tensorflow::Session::CallableHandle handle = get_callable(tensor_feeds, tensor_fetches);
    std::vector<tensorflow::Tensor>     outputs;

    // Set up the inputs (RunCallable expects them in the same order as the feeds in the callable)
    std::vector<tensorflow::Tensor> tf_inputs;
    tf_inputs.reserve(tensor_feeds.size());
    for (auto &item : tensor_feeds)
    {
        tf_inputs.emplace_back(std::move(item.second));
    }

    // Run the callable
    check_tf_status(session_->RunCallable(handle, tf_inputs, &outputs, nullptr));

    // Read the outputs and wrap them in `NeuropodTensor`s; they come back in the same order as `tensor_fetches`
    auto   to_return = stdx::make_unique<NeuropodValueMap>();
    size_t position  = 0;
    for (const auto &item : tensor_fetches)
    {
        const auto &output_name   = item.second;
        auto &      output_tensor = outputs[position++];
        const auto  tensor_type   = get_neuropod_type_from_tf_type(output_tensor.dtype());
        (*to_return)[output_name] = make_tensor<TensorflowNeuropodTensor>(tensor_type, std::move(output_tensor));
    }

    return to_return;
}
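
The `get_callable` helper used above is not part of this excerpt. As the comments note, the feed and fetch maps are ordered so that callables can be cached and reused across calls. Below is a minimal sketch of how such a helper might build and cache a callable; the `callable_cache_` member and the cache-key construction are assumptions for illustration, not the actual implementation.

tensorflow::Session::CallableHandle TensorflowNeuropodBackend::get_callable(
    const std::map<std::string, tensorflow::Tensor> &tensor_feeds,
    const std::map<std::string, std::string> &       tensor_fetches)
{
    // Because the maps are ordered, concatenating their keys yields a stable
    // cache key for a given set of feeds and fetches
    std::string cache_key;
    for (const auto &item : tensor_feeds)
    {
        cache_key += "feed=" + item.first + ";";
    }
    for (const auto &item : tensor_fetches)
    {
        cache_key += "fetch=" + item.first + ";";
    }

    // `callable_cache_` is an assumed std::unordered_map<std::string, CallableHandle> member
    auto cached = callable_cache_.find(cache_key);
    if (cached != callable_cache_.end())
    {
        return cached->second;
    }

    // Describe the subgraph to run: feeds and fetches are graph node names
    tensorflow::CallableOptions opts;
    for (const auto &item : tensor_feeds)
    {
        opts.add_feed(item.first);
    }
    for (const auto &item : tensor_fetches)
    {
        opts.add_fetch(item.first);
    }

    // Build the callable and cache the handle for reuse on subsequent calls
    tensorflow::Session::CallableHandle handle;
    check_tf_status(session_->MakeCallable(opts, &handle));
    callable_cache_.emplace(cache_key, handle);
    return handle;
}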
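
The cast to `NativeDataContainer<tensorflow::Tensor &>` in the input loop works because tensors produced by this backend wrap a native `tensorflow::Tensor` and expose it through that interface. A simplified sketch of the pattern follows; the interface is stripped down and `ExampleTensorflowTensor` is a hypothetical name, while the real role is played by `TensorflowNeuropodTensor` and related classes.

// Simplified interface: a container that can hand out framework-native data
template <typename T>
class NativeDataContainer
{
public:
    virtual ~NativeDataContainer() = default;

    // Return a handle to the underlying framework-native data
    virtual T get_native_data() = 0;
};

// Sketch of a TensorFlow-backed tensor: it owns a tensorflow::Tensor and
// exposes it so the backend can feed it to the session without copying
class ExampleTensorflowTensor : public NativeDataContainer<tensorflow::Tensor &>
{
public:
    explicit ExampleTensorflowTensor(tensorflow::Tensor tensor) : tensor_(std::move(tensor)) {}

    tensorflow::Tensor &get_native_data() override { return tensor_; }

private:
    tensorflow::Tensor tensor_;
};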
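
`check_tf_status` is also defined elsewhere in the backend. A minimal sketch, assuming it simply converts a non-OK `tensorflow::Status` into a Neuropod error:

void check_tf_status(const tensorflow::Status &status)
{
    if (!status.ok())
    {
        // Surface TensorFlow failures through Neuropod's error machinery
        NEUROPOD_ERROR("TensorFlow error: {}", status.error_message());
    }
}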