int64_t TensorflowNeuropodBackend::get_callable()

in source/neuropod/backends/tensorflow/tf_backend.cc [270:314]


// Returns a TF callable handle that feeds the graph nodes named by the keys of
// `tensor_feeds` and fetches the graph nodes named by the keys of `tensor_fetches`.
//
// Handles are memoized in `callable_handle_cache_`, keyed by
// `get_handle_cache_key(tensor_feeds, tensor_fetches)`, so a given feed/fetch
// combination only pays for `MakeCallable` once.
//
// @param tensor_feeds   map whose keys are TF graph node names to feed (values are not read here)
// @param tensor_fetches map whose keys are TF graph node names to fetch (values are not read here)
// @return the cached or newly created `tensorflow::Session::CallableHandle`
// Errors from `MakeCallable` are routed through `check_tf_status`; presumably that
// throws on a non-OK status (confirm). In that case nothing is added to the cache.
int64_t TensorflowNeuropodBackend::get_callable(const std::map<std::string, tensorflow::Tensor> &tensor_feeds,
                                                const std::map<std::string, std::string> &       tensor_fetches)
{
    const auto key = get_handle_cache_key(tensor_feeds, tensor_fetches);

    // Fast path: a callable for this exact feed/fetch combination already exists
    const auto hit = callable_handle_cache_.find(key);
    if (hit != callable_handle_cache_.end())
    {
        return hit->second;
    }

    SPDLOG_DEBUG("TF: Callable cache miss. Creating new callable...");

    // Describe the subgraph to run: which nodes are inputs and which are outputs
    tensorflow::CallableOptions callable_opts;
    for (const auto &feed : tensor_feeds)
    {
        // feed.first is the node name in the TF graph
        callable_opts.add_feed(feed.first);

        // TODO(vip): Once devices are controlled explicitly, also pin the feed device, e.g.
        // callable_opts.mutable_feed_devices()->insert({feed.first, device_name});
    }
    for (const auto &fetch : tensor_fetches)
    {
        // fetch.first is the node name in the TF graph
        callable_opts.add_fetch(fetch.first);
    }

    // Build the callable; it is released in the destructor rather than here
    tensorflow::Session::CallableHandle new_handle{};
    check_tf_status(session_->MakeCallable(callable_opts, &new_handle));

    // Memoize for subsequent calls with the same feeds/fetches
    callable_handle_cache_[key] = new_handle;
    return new_handle;
}