const MSCGraph GraphBuilder::Build()

in src/contrib/msc/core/ir/graph_builder.cc [136:254]


// Build an MSCGraph from the given function.
//
// Steps, in order:
//   1. For every param whose struct info is a tensor, create (or match) a "shape" prim for
//      each dimension that is a tir var. Params with any other defined struct info are fatal.
//   2. Register the params as graph nodes and collect their tensor names as graph inputs.
//   3. Visit the function, then read the graph outputs from the tensors recorded for the
//      body of the function's seq expr.
//   4. Drop nodes listed in weights_ / ignore_nodes_, re-index the remaining nodes and
//      discard every input name that is produced by a dropped or non-"input" node.
//   5. Create the graph and assign aliases to its input and output tensors.
//
// \param func The function to convert.
// \return The built graph.
const MSCGraph GraphBuilder::Build(const Function& func) {
  // Step 1: prims for var dimensions of tensor params.
  for (const auto& param : func->params) {
    if (!param->struct_info_.defined()) {
      continue;
    }
    if (!param->struct_info_.value()->IsInstance<TensorStructInfoNode>()) {
      LOG_FATAL << "Unexpected func param " << param << "(" << param->GetTypeKey() << ")";
    } else {
      const auto& dims = ExprUtils::GetShape(param, false);
      for (size_t axis = 0; axis < dims.size(); axis++) {
        if (!dims[axis]->IsInstance<tvm::tir::VarNode>()) {
          continue;
        }
        Map<String, String> prim_attrs;
        prim_attrs.Set("producer", param->name_hint());
        prim_attrs.Set("out_idx", "0");
        prim_attrs.Set("dim", std::to_string(axis));
        MatchOrCreatePrim(dims[axis], "shape", Array<BaseJoint>(), prim_attrs);
      }
    }
  }

  // Step 2: register params as nodes and record the (unique) input tensor names.
  Array<String> input_names;
  std::set<String> seen_inputs;
  for (const auto& param : func->params) {
    // Already registered, nothing to add.
    if (expr_tensor_map_.count(param)) {
      continue;
    }
    const bool has_binding = func_params_.count(param) > 0;
    // Params bound to an extern func are not graph inputs.
    if (has_binding && func_params_[param]->IsInstance<ExternFuncNode>()) {
      continue;
    }
    if (!(has_binding && func_params_[param]->IsInstance<TupleNode>())) {
      AddNode(param, NullOpt, param->name_hint());
    } else {
      // The param is bound to a tuple: its tensors are the concatenation of the
      // tensors of all tuple fields.
      const auto& tuple = Downcast<Tuple>(func_params_[param]);
      Array<String> field_names;
      for (const auto& field : tuple->fields) {
        if (!expr_tensor_map_.count(field)) {
          const auto* var_node = field.as<VarNode>();
          if (var_node != nullptr) {
            AddNode(field, NullOpt, var_node->name_hint());
          } else {
            LOG_FATAL << "Unexpected tuple input " << field << "(" << field->GetTypeKey() << ")";
          }
        } else {
          LOG_INFO << "Replica tuple input " << field;
        }
        ICHECK(expr_tensor_map_.count(field)) << "Can not find func param from tuple " << field;
        for (const auto& name : expr_tensor_map_[field]) {
          field_names.push_back(name);
        }
      }
      expr_tensor_map_.Set(param, field_names);
    }
    ICHECK(expr_tensor_map_.count(param)) << "Can not find func param " << param;
    for (const auto& name : expr_tensor_map_[param]) {
      // insert() reports whether the name is new, so each name is recorded once.
      if (seen_inputs.insert(name).second) {
        input_names.push_back(name);
      }
    }
  }

  // Step 3: visit the function and fetch the outputs.
  VisitExpr(func);
  const auto& result_expr = func->body->body;
  ICHECK(expr_tensor_map_.count(result_expr)) << "Can not find seqexpr body " << result_expr;
  Array<String> output_names = expr_tensor_map_[result_expr];

  // Step 4: keep only the nodes that are neither weights nor ignored, and remember
  // which tensor names must not be reported as graph inputs.
  Array<MSCJoint> valid_nodes;
  std::set<String> ignore_inputs;
  for (const auto& node : nodes_) {
    const bool discarded = weights_.count(node->name) || ignore_nodes_.count(node->name);
    if (!discarded) {
      node->index = valid_nodes.size();
      valid_nodes.push_back(node);
    }
    if (discarded || node->optype != "input") {
      for (const auto& out : node->outputs) {
        ignore_inputs.insert(out->name);
      }
    }
  }
  Array<String> valid_inputs;
  for (const auto& name : input_names) {
    if (ignore_inputs.count(name)) {
      continue;
    }
    valid_inputs.push_back(name);
  }

  // Step 5: create the graph and set the aliases.
  const MSCGraph graph(name_, valid_nodes, valid_inputs, output_names, prims_);

  // Inputs: use the configured aliases only when there is exactly one per input,
  // otherwise fall back to the name of the producer node.
  const bool custom_input_alias = config_.input_aliases.size() == valid_inputs.size();
  for (size_t i = 0; i < valid_inputs.size(); i++) {
    if (custom_input_alias) {
      graph->FindTensor(valid_inputs[i])->alias = config_.input_aliases[i];
    } else {
      graph->FindTensor(valid_inputs[i])->alias = graph->FindProducer(valid_inputs[i])->name;
    }
  }

  // Outputs: same rule for configured aliases; otherwise an output that has no alias yet
  // gets the producer name (single-output producer) or its own name with ":" replaced by "_".
  const bool custom_output_alias = config_.output_aliases.size() == output_names.size();
  for (size_t i = 0; i < output_names.size(); i++) {
    if (custom_output_alias) {
      graph->FindTensor(output_names[i])->alias = config_.output_aliases[i];
      continue;
    }
    const auto& output = graph->FindTensor(output_names[i]);
    if (output->alias.size() > 0) {
      continue;
    }
    const auto& producer = graph->FindProducer(output_names[i]);
    if (producer->outputs.size() == 1) {
      output->alias = producer->name;
    } else {
      output->alias = StringUtils::Replace(output_names[i], ":", "_");
    }
  }
  return graph;
}