in src/contrib/msc/core/ir/graph_builder.cc [136:254]
/*!
 * \brief Build an MSCGraph from a relax function.
 *
 * The build runs in stages:
 *  1. For every param with a TensorStructInfo, each shape dimension that is a
 *     tir::Var is registered through MatchOrCreatePrim as a "shape" prim.
 *     Params without struct info are skipped; any other struct info is fatal.
 *  2. Input nodes are created for params not yet in expr_tensor_map_. Params
 *     bound (in func_params_) to an ExternFunc are skipped, and params bound
 *     to a Tuple are expanded field by field. Input names are collected in
 *     first-seen order without duplicates.
 *  3. The function is visited; the tensor names recorded for the body of the
 *     function's SeqExpr become the graph outputs.
 *  4. Nodes in weights_ or ignore_nodes_ are dropped and the remaining nodes
 *     are re-indexed. An input name is also dropped when it is an output of a
 *     dropped node or of a kept node whose optype is not "input".
 *  5. The graph is assembled and aliases are assigned to inputs and outputs.
 *
 * \param func The relax function to convert.
 * \return The built graph.
 */
const MSCGraph GraphBuilder::Build(const Function& func) {
  // Add input nodes and record inputs;
  Array<String> input_names, output_names;
  std::set<String> added_inputs;
  // Add prims: every symbolic (tir::Var) dim of a tensor param becomes a
  // "shape" prim tagged with the producing param and the dim position.
  for (const auto& p : func->params) {
    if (!p->struct_info_.defined()) {
      continue;
    }
    const auto& sinfo = p->struct_info_.value();
    if (sinfo->IsInstance<TensorStructInfoNode>()) {
      const auto& shape = ExprUtils::GetShape(p, false);
      for (size_t i = 0; i < shape.size(); i++) {
        if (shape[i]->IsInstance<tvm::tir::VarNode>()) {
          Map<String, String> attrs;
          attrs.Set("producer", p->name_hint());
          attrs.Set("out_idx", "0");
          attrs.Set("dim", std::to_string(i));
          MatchOrCreatePrim(shape[i], "shape", Array<BaseJoint>(), attrs);
        }
      }
    } else {
      // Report the struct info type: it is what made this param unexpected.
      // The Var's own type key would not identify the offending param kind.
      LOG_FATAL << "Unexpected func param " << p << "(" << sinfo->GetTypeKey() << ")";
    }
  }
  // Add input nodes for the params and collect graph input names.
  for (const auto& p : func->params) {
    // Already mapped to tensors (e.g. by an earlier step); nothing to add.
    if (expr_tensor_map_.count(p)) {
      continue;
    }
    // Params bound to an extern function get no input node.
    if (func_params_.count(p) && func_params_[p]->IsInstance<ExternFuncNode>()) {
      continue;
    }
    if (func_params_.count(p) && func_params_[p]->IsInstance<TupleNode>()) {
      // Param bound to a tuple: create (or reuse) one input per field and map
      // the param to the concatenation of the fields' tensor names.
      const auto& tuple = Downcast<Tuple>(func_params_[p]);
      Array<String> tuple_names;
      for (const auto& f : tuple->fields) {
        if (expr_tensor_map_.count(f)) {
          // Field already has tensors; reuse them instead of adding a node.
          LOG_INFO << "Replica tuple input " << f;
        } else if (const auto* f_node = f.as<VarNode>()) {
          AddNode(f, NullOpt, f_node->name_hint());
        } else {
          LOG_FATAL << "Unexpected tuple input " << f << "(" << f->GetTypeKey() << ")";
        }
        ICHECK(expr_tensor_map_.count(f)) << "Can not find func param from tuple " << f;
        for (const auto& name : expr_tensor_map_[f]) {
          tuple_names.push_back(name);
        }
      }
      expr_tensor_map_.Set(p, tuple_names);
    } else {
      AddNode(p, NullOpt, p->name_hint());
    }
    ICHECK(expr_tensor_map_.count(p)) << "Can not find func param " << p;
    // Keep first-seen order while skipping names shared by several params.
    for (const auto& name : expr_tensor_map_[p]) {
      if (!added_inputs.count(name)) {
        input_names.push_back(name);
        added_inputs.insert(name);
      }
    }
  }
  VisitExpr(func);
  ICHECK(expr_tensor_map_.count(func->body->body))
      << "Can not find seqexpr body " << func->body->body;
  output_names = expr_tensor_map_[func->body->body];
  // Remove const nodes as weights, plus explicitly ignored nodes. Tensors they
  // produce, and tensors produced by any non-"input" node, cannot be inputs.
  Array<MSCJoint> valid_nodes;
  std::set<String> ignore_inputs;
  for (const auto& n : nodes_) {
    if (weights_.count(n->name) || ignore_nodes_.count(n->name)) {
      for (const auto& o : n->outputs) {
        ignore_inputs.insert(o->name);
      }
    } else {
      // Re-index so kept nodes are numbered contiguously.
      n->index = valid_nodes.size();
      valid_nodes.push_back(n);
      if (n->optype != "input") {
        for (const auto& o : n->outputs) {
          ignore_inputs.insert(o->name);
        }
      }
    }
  }
  // Remove useless inputs
  Array<String> valid_inputs;
  for (const auto& i : input_names) {
    if (!ignore_inputs.count(i)) {
      valid_inputs.push_back(i);
    }
  }
  // Build graph
  const auto& graph = MSCGraph(name_, valid_nodes, valid_inputs, output_names, prims_);
  // Set input aliases: use configured aliases only when one is given per
  // input, otherwise fall back to the producing node's name.
  if (config_.input_aliases.size() == valid_inputs.size()) {
    for (size_t i = 0; i < valid_inputs.size(); i++) {
      graph->FindTensor(valid_inputs[i])->alias = config_.input_aliases[i];
    }
  } else {
    for (size_t i = 0; i < valid_inputs.size(); i++) {
      graph->FindTensor(valid_inputs[i])->alias = graph->FindProducer(valid_inputs[i])->name;
    }
  }
  // Set output aliases: configured aliases (one per output) override; else keep
  // any existing alias, or derive one from the producer (its name when it has
  // a single output, otherwise the tensor name with ':' replaced by '_').
  if (config_.output_aliases.size() == output_names.size()) {
    for (size_t i = 0; i < output_names.size(); i++) {
      graph->FindTensor(output_names[i])->alias = config_.output_aliases[i];
    }
  } else {
    for (size_t i = 0; i < output_names.size(); i++) {
      const auto& output = graph->FindTensor(output_names[i]);
      if (output->alias.size() > 0) {
        continue;
      }
      const auto& producer = graph->FindProducer(output_names[i]);
      output->alias = producer->outputs.size() == 1
                          ? producer->name
                          : StringUtils::Replace(output_names[i], ":", "_");
    }
  }
  return graph;
}