in src/mlio/parallel_data_reader.cc [187:282]
// Builds the TBB flow graph that reads, decodes, orders, and queues
// example batches in the background:
//
//   source -> limiter -> decode (parallel) -> sequencer -> queue
//
// The limiter caps in-flight decodes; the sequencer restores the original
// batch order; the final node pushes decoded examples into fill_queue_,
// blocking while the queue holds num_prefetched_examples items.
void Parallel_data_reader::init_graph()
{
    namespace flw = tbb::flow;

    std::size_t num_prefetched_examples = params().num_prefetched_examples;
    if (num_prefetched_examples == 0) {
        // Defaults to the number of processor cores.
        num_prefetched_examples =
            static_cast<std::size_t>(tbb::task_scheduler_init::default_num_threads());
    }

    // Reading more batches in parallel than we are willing to prefetch
    // would only pile up work, so clamp to the prefetch count.
    std::size_t num_parallel_reads = params().num_parallel_reads;
    if (num_parallel_reads == 0 || num_parallel_reads > num_prefetched_examples) {
        num_parallel_reads = num_prefetched_examples;
    }

    flw::graph &g = graph_->obj;

    // Source: pulls instance batches from the underlying reader; returns
    // false (ending the stream) once the reader is exhausted. Created
    // inactive (`false`) — it is activated when reading starts.
    auto src_node = std::make_unique<flw::source_node<Batch_msg>>(
        g,
        [this](auto &msg) {
            std::optional<Instance_batch> batch = batch_reader_->read_instance_batch();
            if (batch == std::nullopt) {
                return false;
            }
            msg = Batch_msg{std::make_shared<Instance_batch>(std::move(*batch))};
            return true;
        },
        false);

    // Limiter: allows at most num_parallel_reads batches past this point;
    // the queue node's edge to `decrement` releases a slot per batch.
    auto limit_node = std::make_unique<flw::limiter_node<Batch_msg>>(g, num_parallel_reads);

    // Decode: converts instance batches into examples, in parallel.
    auto decode_node =
        std::make_unique<flw::multifunction_node<Batch_msg, std::tuple<Example_msg>>>(
            g, flw::unlimited, [this](const auto &msg, auto &ports) {
                // We send a message to the next node even if the decode
                // function fails. This is needed to have correct
                // sequential ordering of other batches.
                Example_msg out{msg.batch->index(), this->decode(*msg.batch)};
                if (out.example != nullptr) {
                    num_bytes_read_.fetch_add(msg.batch->size_bytes());
                }
                std::get<0>(ports).try_put(std::move(out));
            });

    // Order: re-sequences decoded examples by their original batch index.
    auto order_node = std::make_unique<flw::sequencer_node<Example_msg>>(g, [](const auto &msg) {
        return msg.idx;
    });

    // Queue: serially appends decoded examples to the fill queue consumed
    // by the reading side.
    auto queue_node = std::make_unique<flw::function_node<Example_msg, flw::continue_msg>>(
        g, flw::serial, [this, num_prefetched_examples](const auto &msg) {
            // If the decode function has failed discard the message.
            if (msg.example == nullptr) {
                return;
            }
            {
                std::unique_lock<std::mutex> queue_lock{queue_mutex_};

                // The predicate must also observe cancellation; otherwise,
                // if graph execution gets cancelled while the queue is full,
                // no consumer will drain it and this wait would deadlock.
                // (The cancelling side is expected to notify fill_condition_.)
                fill_condition_.wait(queue_lock, [this, num_prefetched_examples] {
                    return fill_queue_.size() < num_prefetched_examples ||
                           graph_->ctx.is_group_execution_cancelled();
                });
                if (graph_->ctx.is_group_execution_cancelled()) {
                    return;
                }
                fill_queue_.push_back(msg.example);
            }
            read_condition_.notify_one();
        });

    flw::make_edge(*src_node, *limit_node);
    flw::make_edge(*limit_node, *decode_node);
    flw::make_edge(flw::output_port<0>(*decode_node), *order_node);
    flw::make_edge(*order_node, *queue_node);

    // Completing a batch frees a limiter slot so the source can read again.
    flw::make_edge(*queue_node, limit_node->decrement);

    graph_->src_node = src_node.get();

    graph_->nodes.emplace_back(std::move(src_node));
    graph_->nodes.emplace_back(std::move(limit_node));
    graph_->nodes.emplace_back(std::move(decode_node));
    graph_->nodes.emplace_back(std::move(order_node));
    graph_->nodes.emplace_back(std::move(queue_node));
}