void Parallel_data_reader::init_graph()

in src/mlio/parallel_data_reader.cc [187:282]


void Parallel_data_reader::init_graph()
{
    namespace flw = tbb::flow;

    std::size_t num_prefetched_examples = params().num_prefetched_examples;
    if (num_prefetched_examples == 0) {
        // Defaults to the number of processor cores.
        num_prefetched_examples =
            static_cast<std::size_t>(tbb::task_scheduler_init::default_num_threads());
    }

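    // There is no benefit in reading more batches concurrently than we are
    // willing to prefetch, so clamp the read parallelism to the prefetch depth.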
    std::size_t num_parallel_reads = params().num_parallel_reads;
    if (num_parallel_reads == 0 || num_parallel_reads > num_prefetched_examples) {
        num_parallel_reads = num_prefetched_examples;
    }

    flw::graph &g = graph_->obj;

    // Source: pulls the next Instance_batch from the underlying batch reader;
    // returning false signals end-of-data and stops the node.
    auto src_node = std::make_unique<flw::source_node<Batch_msg>>(
        g,
        [this](auto &msg) {
            std::optional<Instance_batch> batch = batch_reader_->read_instance_batch();
            if (batch == std::nullopt) {
                return false;
            }

            msg = Batch_msg{std::make_shared<Instance_batch>(std::move(*batch))};

            return true;
        },
        /*is_active=*/false);

    // Limiter: admits at most num_parallel_reads batches into the rest of the
    // graph; the decrement edge set up below releases a slot per batch.
    auto limit_node = std::make_unique<flw::limiter_node<Batch_msg>>(g, num_parallel_reads);

    // Decode: converts instance batches into Examples with unlimited
    // concurrency, which means batches can complete out of order.
    auto decode_node =
        std::make_unique<flw::multifunction_node<Batch_msg, std::tuple<Example_msg>>>(
            g, flw::unlimited, [this](const auto &msg, auto &ports) {
                // Send a message to the next node even if the decode
                // function fails; the sequencer downstream expects every
                // batch index, and a missing index would stall all
                // batches that come after it.
                Example_msg out{msg.batch->index(), this->decode(*msg.batch)};

                if (out.example != nullptr) {
                    num_bytes_read_.fetch_add(msg.batch->size_bytes());
                }

                std::get<0>(ports).try_put(std::move(out));
            });

    // Order: restores the original batch order by sequencing messages on
    // their batch index.
    auto order_node = std::make_unique<flw::sequencer_node<Example_msg>>(g, [](const auto &msg) {
        return msg.idx;
    });

    // Queue: serially pushes decoded examples into the prefetch queue from
    // which examples are served to the caller.
    auto queue_node = std::make_unique<flw::function_node<Example_msg, flw::continue_msg>>(
        g, flw::serial, [this, num_prefetched_examples](const auto &msg) {
            // If the decode function has failed, discard the message; it was
            // only forwarded to keep the sequencer in order.
            if (msg.example == nullptr) {
                return;
            }

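            // Block this TBB worker until there is room in the prefetch
            // queue; after waking up, bail out if the graph has been
            // cancelled in the meantime.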
            {
                std::unique_lock<std::mutex> queue_lock{queue_mutex_};

                fill_condition_.wait(queue_lock, [this, num_prefetched_examples] {
                    return fill_queue_.size() < num_prefetched_examples;
                });

                if (graph_->ctx.is_group_execution_cancelled()) {
                    return;
                }

                fill_queue_.push_back(msg.example);
            }

            read_condition_.notify_one();
        });

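    // Wire the pipeline: source -> limiter -> decode -> order -> queue. The
    // final edge feeds the queue node's completion back into the limiter's
    // decrement port, releasing a slot once a batch has made it through the
    // pipeline.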
    flw::make_edge(*src_node, *limit_node);
    flw::make_edge(*limit_node, *decode_node);
    flw::make_edge(flw::output_port<0>(*decode_node), *order_node);
    flw::make_edge(*order_node, *queue_node);
    flw::make_edge(*queue_node, limit_node->decrement);

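    // The source node was constructed inactive; keep a non-owning pointer so
    // that it can be activated later, once reading actually starts.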
    graph_->src_node = src_node.get();

    graph_->nodes.emplace_back(std::move(src_node));
    graph_->nodes.emplace_back(std::move(limit_node));
    graph_->nodes.emplace_back(std::move(decode_node));
    graph_->nodes.emplace_back(std::move(order_node));
    graph_->nodes.emplace_back(std::move(queue_node));
}
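
For orientation, here is a minimal, self-contained sketch of the same pipeline shape: an inactive source, a limiter whose decrement port is fed by the last node, an unlimited "decode" stage, a sequencer that restores order, and a serial sink. Everything in it (Item, the x2 transform, the limits 10 and 3) is a hypothetical stand-in for the mlio types above, and it assumes classic TBB headers, where source_node and limiter_node::decrement are available.

#include <cstddef>
#include <cstdio>
#include <tuple>

#include <tbb/flow_graph.h>

struct Item {
    std::size_t idx;
    int payload;
};

int main()
{
    namespace flw = tbb::flow;

    flw::graph g;

    std::size_t next = 0;

    // Source: emit ten items, then return false to signal end-of-stream.
    flw::source_node<Item> src{
        g,
        [&next](Item &item) {
            if (next == 10) {
                return false;
            }
            item = Item{next, static_cast<int>(next)};
            next++;
            return true;
        },
        /*is_active=*/false};

    // Limiter: at most three items in flight until the sink decrements.
    flw::limiter_node<Item> limit{g, 3};

    // "Decode": unlimited concurrency, so items may finish out of order.
    flw::multifunction_node<Item, std::tuple<Item>> decode{
        g, flw::unlimited, [](const Item &item, auto &ports) {
            std::get<0>(ports).try_put(Item{item.idx, item.payload * 2});
        }};

    // Order: restore the original sequence using the item index.
    flw::sequencer_node<Item> order{g, [](const Item &item) {
        return item.idx;
    }};

    // Sink: consume serially; its continue_msg output feeds the limiter.
    flw::function_node<Item, flw::continue_msg> sink{
        g, flw::serial, [](const Item &item) {
            std::printf("item %zu -> %d\n", item.idx, item.payload);
        }};

    flw::make_edge(src, limit);
    flw::make_edge(limit, decode);
    flw::make_edge(flw::output_port<0>(decode), order);
    flw::make_edge(order, sink);
    flw::make_edge(sink, limit.decrement);

    src.activate();

    g.wait_for_all();
}

The edge into limit.decrement is what creates the backpressure: without it the limiter never regains capacity, and only the first three items would make it through.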