in src/relax/analysis/graph_partitioner.cc [324:441]
void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, //
const DominatorTree& post_dom_tree, //
int phase) {
for (size_t nid = 0; nid < groups_.size(); ++nid) {
// the group of current node has been specified already.
auto* graph_node = graph.post_dfs_order[nid];
auto* dom_node = post_dom_tree.nodes[nid];
Group* group_node = groups_[nid];
ICHECK(group_node != nullptr);
postpone_node_ = nullptr;
// Check if the fusing of some inputs was postponed
if (postponed_fusing_map_.count(graph_node)) {
auto range = postponed_fusing_map_.equal_range(graph_node);
for (auto it = range.first; it != range.second; ++it) {
// If the number of arguments is less than the limit then the input can be fused
if (CountArgs_(graph_node, graph, false) <= CountArgsLimit_(graph_node)) {
auto* src = it->second;
auto* snode = post_dom_tree.nodes[src->index]->parent->gnode;
if (groups_[snode->index]->anchor_ref != nullptr) continue;
CommitFuse(src, snode);
}
}
postponed_fusing_map_.erase(graph_node);
}
// no actions for opaque nodes
if (group_node->pattern == kOpaque) continue;
// no actions needed if the current node have no dominator
if (dom_node->parent == nullptr) continue;
ICHECK(!graph_node->extern_ref);
size_t dom_parent_gindex = dom_node->parent->gnode->index;
// refuse the fusion if too many ops are going to be fused together
if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
continue;
// Refuse the fusion if too many arguments are going to be in the fused function
if (max_function_args_ > 0) {
auto limit = CountArgsLimit_(graph_node);
if (limit > 0) {
if (CountFusedArgs(graph, graph_node) > limit) {
continue;
}
}
}
if (phase == 2) {
// Fuse injective ops into intermediate tuples, if any
if (group_node->pattern > kInjective) continue;
Group* dom_parent_group = groups_[dom_parent_gindex];
Group* dom_root_group = dom_parent_group->FindRoot();
// If dom node group has a tuple as its root, we do not fuse tuple fields into it
if (dom_root_group->pattern == kTuple) continue;
if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) {
// Now we know the tuple has been fused into subsequent injective ops
auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
// dom_root_group can also be tuple, as in inception layers
// CheckPath is needed to avoid fusing two intermediate tuples
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
continue;
}
// Skip if current node is already fused to the parent.
if (groups_[dom_parent_gindex] != nullptr &&
group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
continue;
}
// Do not fuse into tuple for now
if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
// Try to fuse current node to its post-dominator.
if (group_node->pattern == kOutEWiseFusable) {
if (phase != 0) continue;
// Path for OutEWiseFusable: conv2d
// Check if the dominator relation is elemwise.
if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
ICHECK(dom_node->parent->gnode != nullptr);
// The fuse can be executed if all the intermediate ops are still broadcast.
auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; };
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
} else if (group_node->pattern <= kBroadcast) {
// Pre-condition: can only be fused to parent which is injective or reduction.
if (dom_node->parent != nullptr &&
(dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) {
// Check if all the intermediate ops are still broadcast.
// The final terminal node can already be fused to a OutEWiseFusable group.
auto fcond = [](OpPatternKind kind, bool is_sink) {
if (!is_sink) {
// Elemwise, broadcast, and injective ops on the parallel branches
// are allowed be fused to the elemwise/broadcast anchor.
return kind <= kInjective;
} else {
return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective ||
kind == kOutEWiseFusable);
}
};
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
} else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
// defer injective fusion to second phase.
// so conv2d always finishes fusing.
if (phase != 1) continue;
// Check if all path are injective.
auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
} else {
// do nothing.
ICHECK(group_node->pattern == kCommReduce);
}
}
}