in src/contrib/msc/core/ir/graph.cc [1021:1150]
void WeightGraphNode::Build(const MSCGraph& graph, const Map<String, Array<String>>& main_wtypes,
const Map<String, String>& relation_wtypes) {
auto sort_nodes = [&graph](const BaseJoint& node_a, const BaseJoint& node_b) {
return graph->FindProducer(node_a->name)->index < graph->FindProducer(node_b->name)->index;
};
auto find_parents = [this, &main_wtypes, &relation_wtypes, &sort_nodes](const MSCJoint& node) {
std::vector<BaseJoint> parents;
std::queue<MSCJoint> frontier;
std::set<MSCJoint> explored;
for (const auto& p : node->parents) {
frontier.push(Downcast<MSCJoint>(p));
}
while (!frontier.empty()) {
const auto& current = frontier.front();
if (explored.count(current)) {
frontier.pop();
continue;
}
explored.insert(current);
if (main_wtypes.count(current->optype)) {
for (const auto& t_type : main_wtypes[current->optype]) {
if (current->weights.count(t_type)) {
parents.push_back(FindNode(current->WeightAt(t_type)->name));
}
}
} else if (relation_wtypes.count(current->optype)) {
parents.push_back(FindNode(current->OutputAt(0)->name));
} else {
for (const auto& p : current->parents) {
const auto& new_parent = Downcast<MSCJoint>(p);
if (!explored.count(new_parent)) {
frontier.push(new_parent);
}
}
}
frontier.pop();
}
Array<BaseJoint> parents_array;
if (parents.size() > 1) {
std::sort(parents.begin(), parents.end(), sort_nodes);
}
for (const auto& p : parents) {
parents_array.push_back(p);
}
return parents_array;
};
for (const auto& n : graph->node_names) {
const auto& node = graph->FindNode(n);
if (node->shared_ref.size() > 0) {
continue;
}
if (main_wtypes.count(node->optype) || relation_wtypes.count(node->optype) ||
node->weights.size() > 0) {
const auto& w_parents = find_parents(node);
bool bind_friends = true;
if (relation_wtypes.count(node->optype) && relation_wtypes[node->optype] == "multi_inputs") {
bind_friends = false;
}
if (w_parents.size() > 1 && bind_friends) {
for (const auto& p : w_parents) {
Downcast<WeightJoint>(p)->friends = w_parents;
}
}
if (main_wtypes.count(node->optype)) {
for (const auto& wtype : main_wtypes[node->optype]) {
if (node->weights.count(wtype)) {
const auto& weight = node->WeightAt(wtype);
Map<String, String> attrs;
attrs.Set("producer_type", node->optype);
attrs.Set("weight_strategy", "main");
const auto& w_node =
WeightJoint(node_names.size(), weight->name, "", wtype, weight, w_parents, attrs);
for (const auto& p : w_parents) {
p->AddChild(w_node);
}
nodes.Set(weight->name, w_node);
node_names.push_back(weight->name);
}
}
const BaseJoint& head = FindNode(node_names[node_names.size() - 1]);
for (const auto& pair : node->weights) {
if (!nodes.count(pair.second->name)) {
Map<String, String> attrs;
attrs.Set("producer_type", node->optype);
attrs.Set("weight_strategy", "follow");
const auto& w_node = WeightJoint(node_names.size(), pair.second->name, "", pair.first,
pair.second, {head}, attrs);
head->AddChild(w_node);
nodes.Set(pair.second->name, w_node);
node_names.push_back(pair.second->name);
}
}
} else if (relation_wtypes.count(node->optype)) {
const auto& tensor = node->OutputAt(0);
Map<String, String> attrs;
attrs.Set("producer_type", node->optype);
if (node->optype == "reshape") {
// TODO(archermmt): check non-passby reshape
attrs.Set("weight_strategy", "passby");
} else {
attrs.Set("weight_strategy", relation_wtypes[node->optype]);
}
const auto& t_node =
WeightJoint(node_names.size(), tensor->name, "", "output", tensor, w_parents, attrs);
for (const auto& p : w_parents) {
p->AddChild(t_node);
}
nodes.Set(tensor->name, t_node);
node_names.push_back(tensor->name);
} else if (node->weights.size() > 0) {
for (const auto& pair : node->weights) {
if (!nodes.count(pair.second->name)) {
Map<String, String> attrs;
attrs.Set("producer_type", node->optype);
attrs.Set("weight_strategy", "follow");
const auto& w_node = WeightJoint(node_names.size(), pair.second->name, "", pair.first,
pair.second, w_parents, attrs);
for (const auto& p : w_parents) {
p->AddChild(w_node);
}
nodes.Set(pair.second->name, w_node);
node_names.push_back(pair.second->name);
}
}
}
}
}
}