void move_graph_to_device()

in source/neuropod/backends/tensorflow/tf_utils.cc [28:66]


// Returns true if the TensorFlow device string `device` names a CPU device.
//
// Node device strings are not always in the short form "/device:CPU:0". Graphs
// exported from a running session typically carry fully-qualified names such as
// "/job:localhost/replica:0/task:0/device:CPU:0", and older graphs may use the
// legacy "/cpu:0" spelling. Each '/'-separated component is inspected: an optional
// "device:" prefix is stripped, and the device type (the text before the next ':')
// is compared against "CPU" or the legacy lowercase "cpu". Other device types
// (e.g. "GPU", "XLA_CPU") and an empty string are not treated as CPU.
static bool is_cpu_device_name(const std::string &device)
{
    static const std::string device_prefix = "device:";

    std::string::size_type start = 0;
    while (start <= device.size())
    {
        auto end = device.find('/', start);
        if (end == std::string::npos)
        {
            end = device.size();
        }

        std::string component = device.substr(start, end - start);
        if (component.compare(0, device_prefix.size(), device_prefix) == 0)
        {
            component.erase(0, device_prefix.size());
        }

        // Components like "job:localhost" or "replica:0" yield types "job"/"replica"
        const std::string type = component.substr(0, component.find(':'));
        if (type == "CPU" || type == "cpu")
        {
            return true;
        }

        start = end + 1;
    }

    return false;
}

// Rewrites the device assignment of every node in `graph` so the graph runs on `target`.
//
// Target selection:
//   - If `target` is Device::CPU, nodes are moved to "/device:CPU:0".
//   - Otherwise the session's devices are listed; if any has device type "GPU", nodes
//     are moved to "/device:GPU:<target>", else they fall back to "/device:CPU:0".
//
// Nodes that are already pinned to a CPU (in any spelling recognized by
// is_cpu_device_name) are left where they are. Presumably this is because some ops
// only have CPU kernels. Nodes already on exactly `target_device` are also untouched.
//
// NOTE(review): the GPU index `target` is not checked against the number of GPUs the
// session reports; an out-of-range index produces a device name that may fail at
// placement time — confirm whether callers guarantee a valid index.
//
// @param graph   graph whose nodes' `device` fields are modified in place
// @param session session queried with ListDevices (only when target is not CPU);
//                the returned status is passed to check_tf_status
// @param target  requested Neuropod device (Device::CPU or a GPU index)
void move_graph_to_device(tensorflow::GraphDef &graph, tensorflow::Session &session, const NeuropodDevice target)
{
    // Figure out the correct target device
    std::string target_device = "/device:CPU:0";
    if (target != Device::CPU)
    {
        // Get all the available devices
        std::vector<tensorflow::DeviceAttributes> devices;
        check_tf_status(session.ListDevices(&devices));

        // Check if we have any GPUs
        const bool found_gpu =
            std::any_of(devices.begin(), devices.end(), [](const tensorflow::DeviceAttributes &device) {
                return device.device_type() == "GPU";
            });

        // If we have a GPU, update the target device; otherwise stay on CPU
        if (found_gpu)
        {
            target_device = std::string("/device:GPU:") + std::to_string(target);
        }
    }

    // Iterate through all the nodes in the graph and move them to the target device
    for (auto &node : *graph.mutable_node())
    {
        // Aliases the node's own field, so it is only read before set_device() below
        const auto &node_device = node.device();

        // If a node is on CPU (however its device name is spelled), leave it there
        if (!is_cpu_device_name(node_device) && node_device != target_device)
        {
            SPDLOG_TRACE("TF: Moving node {} from device {} to device {}", node.name(), node_device, target_device);
            node.set_device(target_device);
        }
        else
        {
            SPDLOG_TRACE("TF: Leaving node {} on device {}", node.name(), node_device);
        }
    }
}