in src/contrib/msc/core/ir/graph_builder.cc [256:585]
const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional<Expr>& binding_var,
const String& name) {
// Get optype, node_name and layout
String node_name = name.size() > 0 ? name : SpanUtils::GetAttr(expr->span, msc_attr::kName);
String optype = "unknown";
String layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout);
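// Infer optype from the expression kind; a function parameter bound to a
// constant (e.g. a lifted weight) is treated as a constant node.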
if (func_params_.count(expr) && func_params_[expr]->IsInstance<ConstantNode>()) {
node_name = SpanUtils::GetAttr(func_params_[expr]->span, msc_attr::kName);
optype = "constant";
} else if (expr->IsInstance<VarNode>()) {
optype = "input";
} else if (expr->IsInstance<ConstantNode>()) {
optype = "constant";
} else if (expr->IsInstance<ShapeExprNode>()) {
optype = "shape";
} else if (expr->IsInstance<TupleGetItemNode>()) {
optype = "get_item";
} else if (expr->IsInstance<TupleNode>()) {
optype = "tuple";
} else if (const auto* call_node = expr.as<CallNode>()) {
if (const auto* op_node = call_node->op.as<OpNode>()) {
if (op_node->name == "relax.call_dps_packed") {
optype = Downcast<ExternFunc>(call_node->args[0])->global_symbol;
} else {
optype = StringUtils::Replace(op_node->name, "relax.", "");
}
} else if (const auto* v_node = call_node->op.as<GlobalVarNode>()) {
const auto& func = Downcast<Function>(ref_module_->Lookup(v_node->name_hint));
std::tie(node_name, optype, layout) = ParseFunc(func);
} else if (call_node->op->IsInstance<VarNode>()) {
ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op;
std::tie(node_name, optype, layout) = ParseFunc(target_funcs_[call_node->op]);
} else if (call_node->op->IsInstance<FunctionNode>()) {
std::tie(node_name, optype, layout) = ParseFunc(Downcast<Function>(call_node->op));
}
}
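// A layout registered for this node name overrides the span-recorded layout.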
if (layouts_.count(node_name)) {
layout = layouts_[node_name];
}
// Special case for tuple: a tuple produced by calling a local function is
// built by inlining the function body and reusing its outputs
if (optype == "tuple" && expr->IsInstance<CallNode>() &&
Downcast<Call>(expr)->op->IsInstance<VarNode>()) {
const auto& call_node = Downcast<Call>(expr);
ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op;
const auto& tuple_func = target_funcs_[call_node->op];
for (size_t i = 0; i < call_node->args.size(); i++) {
expr_tensor_map_.Set(tuple_func->params[i], expr_tensor_map_[call_node->args[i]]);
}
VisitExpr(tuple_func);
ICHECK(expr_tensor_map_.count(tuple_func->body->body))
<< "Can not find seqexpr body " << tuple_func->body->body;
const auto& outputs = expr_tensor_map_[tuple_func->body->body];
const auto& ref_expr = binding_var.defined() ? binding_var.value() : expr;
expr_tensor_map_.Set(ref_expr, outputs);
ICHECK(tensor_input_map_.count(outputs[0])) << "Can not find tensor " << outputs[0];
return Downcast<MSCJoint>(tensor_input_map_[outputs[0]].first);
}
// Get plugin if optype matches a registered plugin
const auto& plugin = IsPlugin(optype) ? GetPlugin(optype) : Plugin();
// Extract normal attributes
Map<String, String> attrs;
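// Attributes come from different sources by op kind: plugin attribute lists,
// function bodies scanned with FuncAttrGetter, the call attrs via VisitAttrs,
// or literal values for scalar constants, shapes and get_item.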
if (plugin.defined()) {
const auto& op = Downcast<Call>(expr)->op;
if (target_funcs_.count(op)) {
const auto& opattrs_opt = target_funcs_[op]->GetAttr<Array<String>>(msc_attr::kOpattrs);
if (opattrs_opt.defined()) {
const auto& opattrs = opattrs_opt.value();
ICHECK_EQ(opattrs.size(), plugin->attrs.size())
<< "opattrs " << opattrs << " size mismatch with " << plugin->attrs.size();
for (size_t i = 0; i < opattrs.size(); i++) {
attrs.Set(plugin->attrs[i]->name, opattrs[i]);
}
}
} else {
const auto& args = GetPluginInputs(expr);
for (size_t i = 0; i < plugin->attrs.size(); i++) {
const auto& val = args[plugin->inputs.size() + i];
attrs.Set(plugin->attrs[i]->name, StringUtils::ToString(val));
}
}
} else if (const auto* call_node = expr.as<CallNode>()) {
if (const auto* v_node = call_node->op.as<GlobalVarNode>()) {
const auto& func = Downcast<Function>(ref_module_->Lookup(v_node->name_hint));
const auto& name_opt = func->GetAttr<runtime::String>(relax::attr::kComposite);
if (name_opt.defined()) {
attrs = FuncAttrGetter().GetAttrs(func);
}
} else if (call_node->op->IsInstance<VarNode>()) {
ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op;
attrs = FuncAttrGetter().GetAttrs(target_funcs_[call_node->op]);
} else if (call_node->op->IsInstance<FunctionNode>()) {
attrs = FuncAttrGetter().GetAttrs(call_node->op);
} else if (call_node->attrs.defined()) {
AttrGetter getter(&attrs);
const_cast<BaseAttrsNode*>(call_node->attrs.get())->VisitAttrs(&getter);
}
} else if (const auto* const_node = expr.as<ConstantNode>()) {
if (const_node->is_scalar()) {
attrs.Set("scalar", GetScalarStr(const_node->data, config_.float_precision));
}
} else if (const auto* shape_node = expr.as<ShapeExprNode>()) {
attrs.Set("shape", StringUtils::ToString(shape_node->values));
} else if (const auto* get_node = expr.as<TupleGetItemNode>()) {
attrs.Set("index", std::to_string(get_node->index));
}
// Extract attributes from arguments
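// input_types labels each argument position as either a data input ("input")
// or an attribute key; non-input positions are stored in attrs instead of
// becoming graph edges.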
Array<String> input_types;
if (!plugin.defined() && expr->IsInstance<CallNode>()) {
const auto& call = Downcast<Call>(expr);
Array<String> values;
if (call->op->IsInstance<VarNode>()) {
ICHECK(target_funcs_.count(call->op)) << "Can not find target func: " << call->op;
values = FuncValueGetter().GetValues(target_funcs_[call->op]);
}
input_types = ExprUtils::GetInputTypes(optype, call->args.size() + values.size(), true);
for (size_t i = 0; i < call->args.size(); i++) {
const auto& arg = call->args[i];
if (const auto* s_node = arg.as<ShapeExprNode>()) {
attrs.Set(input_types[i], StringUtils::ToString(s_node->values));
} else if (func_params_.count(arg) && func_params_[arg]->IsInstance<ShapeExprNode>()) {
const auto* s_node = func_params_[arg].as<ShapeExprNode>();
attrs.Set(input_types[i], StringUtils::ToString(s_node->values));
ignore_nodes_.insert(Downcast<Var>(arg)->name_hint());
} else if (const auto* s_node = arg.as<PrimValueNode>()) {
ICHECK(input_types[i] != "input") << i << " th PrimValue of " << optype
<< " should has special type, get " << input_types;
attrs.Set(input_types[i], StringUtils::ToString(s_node->value));
} else if (input_types[i] != "input" && arg->IsInstance<TupleNode>()) {
attrs.Set(input_types[i], StringUtils::ToString(arg));
}
}
for (size_t i = call->args.size(); i < input_types.size(); i++) {
attrs.Set(input_types[i], values[i - call->args.size()]);
}
}
// Build inputs and weights
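// Plugin inputs are taken directly from the recorded arg tensors; for normal
// calls, constant args at non-input positions become weights while the rest
// contribute input edges.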
Array<String> input_names;
Map<String, MSCTensor> node_weights;
if (plugin.defined()) {
const auto& call = Downcast<Call>(expr);
if (call->args.size() == 1) {
ICHECK(expr_tensor_map_.count(call->args[0]))
<< "Can not find tuple plugin input " << call->args[0];
input_names = expr_tensor_map_[call->args[0]];
} else {
const auto& args = GetPluginInputs(expr);
for (size_t i = 0; i < plugin->inputs.size(); i++) {
ICHECK(expr_tensor_map_.count(args[i])) << "Can not find plugin input " << args[i];
for (const auto& in_name : expr_tensor_map_[args[i]]) {
input_names.push_back(in_name);
}
}
}
} else if (const auto* call_node = expr.as<CallNode>()) {
for (size_t i = 0; i < call_node->args.size(); i++) {
if (attrs.count(input_types[i])) {
continue;
}
const auto& arg = call_node->args[i];
Array<String> arg_names;
if (expr_tensor_map_.count(arg)) {
arg_names = expr_tensor_map_[arg];
} else if (input_types[i] == "input" && arg->IsInstance<TupleNode>()) {
const auto* tuple_node = arg.as<TupleNode>();
for (const auto& f : tuple_node->fields) {
ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f;
for (const auto& in_name : expr_tensor_map_[f]) {
arg_names.push_back(in_name);
}
}
}
String weight_name;
if (input_types[i] != "input" && arg->IsInstance<ConstantNode>()) {
weight_name = SpanUtils::GetAttr(arg->span, msc_attr::kName);
} else if (input_types[i] != "input" && func_params_.count(arg) &&
func_params_[arg]->IsInstance<ConstantNode>()) {
weight_name = SpanUtils::GetAttr(func_params_[arg]->span, msc_attr::kName);
ignore_nodes_.insert(Downcast<Var>(arg)->name_hint());
}
// set weights or inputs
if (weight_name.size() > 0) {
const auto& t_name = arg_names[0];
const auto& pair = tensor_input_map_[t_name];
const auto& producer = Downcast<MSCJoint>(pair.first);
if (!weights_.count(weight_name)) {
const auto& ref = producer->OutputAt(pair.second);
MSCTensor weight;
if (input_types[i] == "bias") {
weight = MSCTensor(weight_name, ref->dtype, "O", Array<Integer>{ref->GetSize()});
} else if (input_types[i] == "weight" &&
(optype == "msc.linear" || optype == "msc.linear_bias")) {
if (ref->layout.name() == "IO") {
String valid_layout = ref->layout[1].name() + ref->layout[0].name();
const auto& valid_shape = Array<Integer>({ref->shape[1], ref->shape[0]});
weight = MSCTensor(weight_name, ref->dtype, valid_layout, valid_shape);
} else {
weight = MSCTensor(weight_name, ref->dtype, ref->layout.name(), ref->shape);
}
} else {
weight = MSCTensor(weight_name, ref->dtype, ref->layout.name(), ref->shape);
}
weights_.Set(weight_name, weight);
}
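// A scalar constant producer also exposes its value as an attribute on this node.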
if (producer->HasAttr("scalar")) {
attrs.Set(input_types[i], producer->GetTypeAttr<std::string>("scalar"));
}
node_weights.Set(input_types[i], weights_[weight_name]);
} else {
for (const auto& in_name : arg_names) {
input_names.push_back(in_name);
}
}
}
} else if (const auto* tuple_node = expr.as<TupleNode>()) {
for (const auto& f : tuple_node->fields) {
ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f;
for (const auto& in_name : expr_tensor_map_[f]) {
input_names.push_back(in_name);
}
}
} else if (const auto* getitem_node = expr.as<TupleGetItemNode>()) {
ICHECK(expr_tensor_map_.count(getitem_node->tuple))
<< "Can not find tuple " << getitem_node->tuple;
input_names = expr_tensor_map_[getitem_node->tuple];
} else if (optype == "constant") {
const auto& t_info = Downcast<TensorStructInfo>(GetStructInfo(expr));
const auto& shape_opt = t_info->GetShape();
ICHECK(shape_opt.defined()) << "Constant shape is not defined";
const auto& weight =
MSCTensor(node_name, t_info->dtype, layout, ArrayUtils::Cast<Integer>(shape_opt.value()));
node_weights.Set("const", weight);
}
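// Resolve each input name to its producer node and output index.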
std::vector<std::pair<BaseJoint, size_t>> inputs;
for (const auto& i : input_names) {
inputs.push_back(tensor_input_map_[i]);
}
// Redefine layout for special ops
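// A tuple joins the layouts of its fields with commas; get_item inherits the
// layout of the selected field.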
if (optype == "tuple") {
layout = "";
for (size_t i = 0; i < inputs.size(); i++) {
const auto& in_tensor = Downcast<MSCJoint>(inputs[i].first)->OutputAt(inputs[i].second);
layout = layout + in_tensor->layout.name();
layout = layout + (i == inputs.size() - 1 ? "" : ",");
}
} else if (optype == "get_item") {
int idx = std::stoi(attrs["index"]);
const auto& in_tensor = Downcast<MSCJoint>(inputs[idx].first)->OutputAt(inputs[idx].second);
layout = in_tensor->layout.name();
}
// Build output tensor
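// build_output converts a TensorStructInfo into an MSCTensor, recording
// symbolic dims found in prim_map_ by their prim names.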
auto build_output = [this](const StructInfo& sinfo, const String& node_name,
const String& layout) {
ICHECK(sinfo->IsInstance<TensorStructInfoNode>())
<< "sinfo should be TensorStructInfo, get " << sinfo->GetTypeKey();
const auto& t_info = Downcast<TensorStructInfo>(sinfo);
const auto& shape = ArrayUtils::Cast<Integer>(ExprUtils::GetShape(t_info));
Array<String> prims;
bool has_prims = false;
if (shape.size() > 0) {
for (const auto& s : t_info->GetShape().value()) {
if (prim_map_.count(s)) {
prims.push_back(prim_map_[s]->name);
has_prims = true;
} else {
prims.push_back(StringUtils::ToString(s));
}
}
}
if (has_prims) {
return MSCTensor(node_name, t_info->dtype, layout, shape, "", prims);
}
return MSCTensor(node_name, t_info->dtype, layout, shape);
};
// Gather outputs
Array<MSCTensor> outputs;
const auto& sinfo = GetStructInfo(expr);
Array<String> layouts = StringUtils::Split(layout, ",");
size_t num_output = 1;
if (const auto* tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
num_output = tuple_sinfo->fields.size();
}
if (layouts.size() == 0) {
layouts = Array<String>(num_output, "");
}
ICHECK_EQ(layouts.size(), num_output)
<< "Layouts " << layouts << " msimatch with output size " << num_output;
if (sinfo->IsInstance<TensorStructInfoNode>()) {
const auto& t_name = node_name + ":" + std::to_string(0);
outputs.push_back(build_output(sinfo, t_name, layouts[0]));
} else if (const auto* s_sinfo = sinfo.as<ShapeStructInfoNode>()) {
Array<Integer> shape{s_sinfo->ndim};
const auto& t_name = node_name + ":" + std::to_string(0);
const auto& dtype = DataType(runtime::StringToDLDataType("int32"));
outputs.push_back(MSCTensor(t_name, dtype, layouts[0], shape));
} else if (const auto* tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
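// nn.batch_norm yields a 3-field tuple (out, mean, var); only the first
// field is materialized as a node output.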
size_t field_size = optype == "nn.batch_norm" ? 1 : num_output;
for (size_t i = 0; i < field_size; i++) {
const auto& t_name = node_name + ":" + std::to_string(i);
outputs.push_back(build_output(tuple_sinfo->fields[i], t_name, layouts[i]));
}
} else {
LOG(FATAL) << "Unexpected struct info (" << sinfo->GetTypeKey() << ")" << sinfo;
}
// Build node
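// Inputs and constants are scope-free; other nodes record the current scope
// path (scope_name_ split on ".").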
Array<String> scope;
if (optype != "input" && optype != "constant") {
scope = StringUtils::Split(scope_name_, ".");
}
const auto& shared_ref = SpanUtils::GetAttr(expr->span, msc_attr::kSharedRef);
const auto& node = MSCJoint(nodes_.size(), node_name, shared_ref, optype, attrs, scope, inputs,
outputs, node_weights);
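// Register the outputs so later expressions can resolve them as inputs, and
// map the bound var (if any) to the output names.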
Array<String> output_names;
for (size_t i = 0; i < outputs.size(); i++) {
output_names.push_back(outputs[i]->name);
tensor_input_map_[outputs[i]->name] = std::make_pair(node, i);
}
nodes_.push_back(node);
const auto& ref_expr = binding_var.defined() ? binding_var.value() : expr;
expr_tensor_map_.Set(ref_expr, output_names);
return node;
}