in src/core/scheduler/scheduler.cc [579:669]
void Graph::AnalyzeNodes() {
if (in_serial_) {
begin_nodes_.push_back(nodes_[0]);
for (size_t i = 0; i < nodes_.size(); ++i) {
Node *curNode = nodes_[i];
next_nodes_[i].clear();
if (i + 1 < nodes_.size()) {
next_nodes_[i].push_back(nodes_[i + 1]);
}
BlockSet blks;
for (size_t j = 0; j < curNode->in_edges_.size(); ++j) {
blks.insert(curNode->in_edges_[j]->blk_);
}
for (size_t j = 0; j < curNode->out_edges_.size(); ++j) {
blks.insert(curNode->out_edges_[j]->blk_);
}
for (auto &it : blks) {
blocks_[it]->used_nodes_.push_back(curNode);
}
}
} else {
// init node ref
std::vector<int> node_ref_;
node_ref_.resize(nodes_.size());
for (size_t i = 0; i < nodes_.size(); ++i) {
node_ref_[i] = nodes_[i]->in_edges_.size();
}
// find all input edges and decrease ref count of nodes
for (size_t i = 0; i < edges_.size(); ++i) {
Node *src_node = edges_[i]->src_node_;
if (!src_node) {
Node *node = edges_[i]->dst_node_;
int nodeId = node->id_;
node_ref_[nodeId] -= 1;
}
}
// activate nodes
SafeQueue<Node *> node_queue;
for (size_t i = 0; i < node_ref_.size(); ++i) {
if (node_ref_[i] == 0) {
begin_nodes_.push_back(nodes_[i]);
node_queue.Push(nodes_[i]);
}
}
// run graph
while (node_queue.Size()) {
// step 1: pop the first element, get the node corresponding to the index
Node *curNode = nullptr;
node_queue.Pop(curNode);
int curIndex = curNode->id_;
// step 2: decrease ref count of nodes and activate nodes
next_nodes_[curIndex].clear();
for (size_t i = 0; i < curNode->out_edges_.size(); ++i) {
Edge *edge = curNode->out_edges_[i];
Node *nextNode = edge->dst_node_;
if (!nextNode) {
continue;
}
int nodeId = nextNode->id_;
node_ref_[nodeId] -= 1;
if (node_ref_[nodeId] <= 0) {
node_queue.Push(nextNode);
next_nodes_[curIndex].push_back(nextNode);
}
}
// step 3: push_back curNode to the used_nodes_ of relevant blocks
BlockSet blks;
for (size_t j = 0; j < curNode->in_edges_.size(); ++j) {
blks.insert(curNode->in_edges_[j]->blk_);
}
for (size_t j = 0; j < curNode->out_edges_.size(); ++j) {
blks.insert(curNode->out_edges_[j]->blk_);
}
for (auto &it : blks) {
blocks_[it]->used_nodes_.push_back(curNode);
}
}
}
}