void WeightGraphNode::Build()

in src/contrib/msc/core/ir/graph.cc [1021:1150]


void WeightGraphNode::Build(const MSCGraph& graph, const Map<String, Array<String>>& main_wtypes,
                            const Map<String, String>& relation_wtypes) {
  auto sort_nodes = [&graph](const BaseJoint& node_a, const BaseJoint& node_b) {
    return graph->FindProducer(node_a->name)->index < graph->FindProducer(node_b->name)->index;
  };

  auto find_parents = [this, &main_wtypes, &relation_wtypes, &sort_nodes](const MSCJoint& node) {
    std::vector<BaseJoint> parents;
    std::queue<MSCJoint> frontier;
    std::set<MSCJoint> explored;
    for (const auto& p : node->parents) {
      frontier.push(Downcast<MSCJoint>(p));
    }
    while (!frontier.empty()) {
      const auto& current = frontier.front();
      if (explored.count(current)) {
        frontier.pop();
        continue;
      }
      explored.insert(current);
      if (main_wtypes.count(current->optype)) {
        for (const auto& t_type : main_wtypes[current->optype]) {
          if (current->weights.count(t_type)) {
            parents.push_back(FindNode(current->WeightAt(t_type)->name));
          }
        }
      } else if (relation_wtypes.count(current->optype)) {
        parents.push_back(FindNode(current->OutputAt(0)->name));
      } else {
        for (const auto& p : current->parents) {
          const auto& new_parent = Downcast<MSCJoint>(p);
          if (!explored.count(new_parent)) {
            frontier.push(new_parent);
          }
        }
      }
      frontier.pop();
    }
    Array<BaseJoint> parents_array;
    if (parents.size() > 1) {
      std::sort(parents.begin(), parents.end(), sort_nodes);
    }
    for (const auto& p : parents) {
      parents_array.push_back(p);
    }
    return parents_array;
  };

  for (const auto& n : graph->node_names) {
    const auto& node = graph->FindNode(n);
    if (node->shared_ref.size() > 0) {
      continue;
    }
    if (main_wtypes.count(node->optype) || relation_wtypes.count(node->optype) ||
        node->weights.size() > 0) {
      const auto& w_parents = find_parents(node);
      bool bind_friends = true;
      if (relation_wtypes.count(node->optype) && relation_wtypes[node->optype] == "multi_inputs") {
        bind_friends = false;
      }
      if (w_parents.size() > 1 && bind_friends) {
        for (const auto& p : w_parents) {
          Downcast<WeightJoint>(p)->friends = w_parents;
        }
      }
      if (main_wtypes.count(node->optype)) {
        for (const auto& wtype : main_wtypes[node->optype]) {
          if (node->weights.count(wtype)) {
            const auto& weight = node->WeightAt(wtype);
            Map<String, String> attrs;
            attrs.Set("producer_type", node->optype);
            attrs.Set("weight_strategy", "main");
            const auto& w_node =
                WeightJoint(node_names.size(), weight->name, "", wtype, weight, w_parents, attrs);
            for (const auto& p : w_parents) {
              p->AddChild(w_node);
            }
            nodes.Set(weight->name, w_node);
            node_names.push_back(weight->name);
          }
        }
        const BaseJoint& head = FindNode(node_names[node_names.size() - 1]);
        for (const auto& pair : node->weights) {
          if (!nodes.count(pair.second->name)) {
            Map<String, String> attrs;
            attrs.Set("producer_type", node->optype);
            attrs.Set("weight_strategy", "follow");
            const auto& w_node = WeightJoint(node_names.size(), pair.second->name, "", pair.first,
                                             pair.second, {head}, attrs);
            head->AddChild(w_node);
            nodes.Set(pair.second->name, w_node);
            node_names.push_back(pair.second->name);
          }
        }
      } else if (relation_wtypes.count(node->optype)) {
        const auto& tensor = node->OutputAt(0);
        Map<String, String> attrs;
        attrs.Set("producer_type", node->optype);
        if (node->optype == "reshape") {
          // TODO(archermmt): check non-passby reshape
          attrs.Set("weight_strategy", "passby");
        } else {
          attrs.Set("weight_strategy", relation_wtypes[node->optype]);
        }
        const auto& t_node =
            WeightJoint(node_names.size(), tensor->name, "", "output", tensor, w_parents, attrs);
        for (const auto& p : w_parents) {
          p->AddChild(t_node);
        }
        nodes.Set(tensor->name, t_node);
        node_names.push_back(tensor->name);
      } else if (node->weights.size() > 0) {
        for (const auto& pair : node->weights) {
          if (!nodes.count(pair.second->name)) {
            Map<String, String> attrs;
            attrs.Set("producer_type", node->optype);
            attrs.Set("weight_strategy", "follow");
            const auto& w_node = WeightJoint(node_names.size(), pair.second->name, "", pair.first,
                                             pair.second, w_parents, attrs);
            for (const auto& p : w_parents) {
              p->AddChild(w_node);
            }
            nodes.Set(pair.second->name, w_node);
            node_names.push_back(pair.second->name);
          }
        }
      }
    }
  }
}