const MSCJoint GraphBuilder::AddNode()

in src/contrib/msc/core/ir/graph_builder.cc [256:585]


// Convert one relax expression into an MSCJoint and register it in the graph.
//
// \param expr        The expression to convert (Var, Constant, ShapeExpr, Tuple,
//                    TupleGetItem or Call).
// \param binding_var If defined, the output tensor names are recorded under this
//                    var in expr_tensor_map_; otherwise they are recorded under `expr`.
// \param name        Explicit node name. When empty, the name is read from the span
//                    attribute msc_attr::kName. For params bound to constants and for
//                    calls into functions it may be replaced by the constant's / callee's name.
// \return The newly created node, which is appended to nodes_. Special case: for a
//         "tuple" call into a target (Var) function no node is created; the function is
//         visited and the node producing its first output tensor is returned instead.
const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional<Expr>& binding_var,
                                     const String& name) {
  // Get optype, node_name and layout
  String node_name = name.size() > 0 ? name : SpanUtils::GetAttr(expr->span, msc_attr::kName);
  String optype = "unknown";
  String layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout);
  if (func_params_.count(expr) && func_params_[expr]->IsInstance<ConstantNode>()) {
    // A function param bound to a constant is treated as that constant.
    node_name = SpanUtils::GetAttr(func_params_[expr]->span, msc_attr::kName);
    optype = "constant";
  } else if (expr->IsInstance<VarNode>()) {
    optype = "input";
  } else if (expr->IsInstance<ConstantNode>()) {
    optype = "constant";
  } else if (expr->IsInstance<ShapeExprNode>()) {
    optype = "shape";
  } else if (expr->IsInstance<TupleGetItemNode>()) {
    optype = "get_item";
  } else if (expr->IsInstance<TupleNode>()) {
    optype = "tuple";
  } else if (const auto* call_node = expr.as<CallNode>()) {
    if (const auto* op_node = call_node->op.as<OpNode>()) {
      if (op_node->name == "relax.call_dps_packed") {
        // For packed calls the optype is the extern function's symbol.
        optype = Downcast<ExternFunc>(call_node->args[0])->global_symbol;
      } else {
        optype = StringUtils::Replace(op_node->name, "relax.", "");
      }
    } else if (const auto* v_node = call_node->op.as<GlobalVarNode>()) {
      const auto& func = Downcast<Function>(ref_module_->Lookup(v_node->name_hint));
      std::tie(node_name, optype, layout) = ParseFunc(func);
    } else if (call_node->op->IsInstance<VarNode>()) {
      ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op;
      std::tie(node_name, optype, layout) = ParseFunc(target_funcs_[call_node->op]);
    } else if (call_node->op->IsInstance<FunctionNode>()) {
      std::tie(node_name, optype, layout) = ParseFunc(Downcast<Function>(call_node->op));
    }
  }
  // Explicitly configured layouts take precedence over the derived one.
  if (layouts_.count(node_name)) {
    layout = layouts_[node_name];
  }

  // special case for tuple: inline the target function instead of creating a node
  if (optype == "tuple" && expr->IsInstance<CallNode>() &&
      Downcast<Call>(expr)->op->IsInstance<VarNode>()) {
    const auto& call_node = Downcast<Call>(expr);
    ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op;
    const auto& tuple_func = target_funcs_[call_node->op];
    // Bind the function params to the tensors of the call arguments before visiting.
    for (size_t i = 0; i < call_node->args.size(); i++) {
      expr_tensor_map_.Set(tuple_func->params[i], expr_tensor_map_[call_node->args[i]]);
    }
    VisitExpr(tuple_func);
    ICHECK(expr_tensor_map_.count(tuple_func->body->body))
        << "Can not find seqexpr body " << tuple_func->body->body;
    const auto& outputs = expr_tensor_map_[tuple_func->body->body];
    const auto& ref_expr = binding_var.defined() ? binding_var.value() : expr;
    expr_tensor_map_.Set(ref_expr, outputs);
    ICHECK(tensor_input_map_.count(outputs[0])) << "Can not find tensor " << outputs[0];
    return Downcast<MSCJoint>(tensor_input_map_[outputs[0]].first);
  }

  // get plugin
  const auto& plugin = IsPlugin(optype) ? GetPlugin(optype) : Plugin();

  // Extract normal attributes
  Map<String, String> attrs;
  if (plugin.defined()) {
    const auto& op = Downcast<Call>(expr)->op;
    if (target_funcs_.count(op)) {
      // Plugin attrs are stored positionally in the function's kOpattrs attribute.
      const auto& opattrs_opt = target_funcs_[op]->GetAttr<Array<String>>(msc_attr::kOpattrs);
      if (opattrs_opt.defined()) {
        const auto& opattrs = opattrs_opt.value();
        ICHECK_EQ(opattrs.size(), plugin->attrs.size())
            << "opattrs " << opattrs << " size mismatch with " << plugin->attrs.size();
        for (size_t i = 0; i < opattrs.size(); i++) {
          attrs.Set(plugin->attrs[i]->name, opattrs[i]);
        }
      }
    } else {
      // Otherwise plugin attrs follow the plugin inputs in the argument list.
      const auto& args = GetPluginInputs(expr);
      for (size_t i = 0; i < plugin->attrs.size(); i++) {
        const auto& val = args[plugin->inputs.size() + i];
        attrs.Set(plugin->attrs[i]->name, StringUtils::ToString(val));
      }
    }
  } else if (const auto* call_node = expr.as<CallNode>()) {
    if (const auto* v_node = call_node->op.as<GlobalVarNode>()) {
      const auto& func = Downcast<Function>(ref_module_->Lookup(v_node->name_hint));
      const auto& name_opt = func->GetAttr<runtime::String>(relax::attr::kComposite);
      if (name_opt.defined()) {
        attrs = FuncAttrGetter().GetAttrs(func);
      }
    } else if (call_node->op->IsInstance<VarNode>()) {
      ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op;
      attrs = FuncAttrGetter().GetAttrs(target_funcs_[call_node->op]);
    } else if (call_node->op->IsInstance<FunctionNode>()) {
      attrs = FuncAttrGetter().GetAttrs(call_node->op);
    } else if (call_node->attrs.defined()) {
      AttrGetter getter(&attrs);
      const_cast<BaseAttrsNode*>(call_node->attrs.get())->VisitAttrs(&getter);
    }
  } else if (const auto* const_node = expr.as<ConstantNode>()) {
    if (const_node->is_scalar()) {
      attrs.Set("scalar", GetScalarStr(const_node->data, config_.float_precision));
    }
  } else if (const auto* shape_node = expr.as<ShapeExprNode>()) {
    attrs.Set("shape", StringUtils::ToString(shape_node->values));
  } else if (const auto* get_node = expr.as<TupleGetItemNode>()) {
    attrs.Set("index", std::to_string(get_node->index));
  }

  // Extract attributes from arguments
  Array<String> input_types;
  if (!plugin.defined() && expr->IsInstance<CallNode>()) {
    const auto& call = Downcast<Call>(expr);
    Array<String> values;
    if (call->op->IsInstance<VarNode>()) {
      ICHECK(target_funcs_.count(call->op)) << "Can not find target func: " << call->op;
      values = FuncValueGetter().GetValues(target_funcs_[call->op]);
    }
    input_types = ExprUtils::GetInputTypes(optype, call->args.size() + values.size(), true);
    for (size_t i = 0; i < call->args.size(); i++) {
      const auto& arg = call->args[i];
      if (const auto* s_node = arg.as<ShapeExprNode>()) {
        attrs.Set(input_types[i], StringUtils::ToString(s_node->values));
      } else if (func_params_.count(arg) && func_params_[arg]->IsInstance<ShapeExprNode>()) {
        // A param bound to a shape is folded into attrs; its own node is ignored.
        const auto* s_node = func_params_[arg].as<ShapeExprNode>();
        attrs.Set(input_types[i], StringUtils::ToString(s_node->values));
        ignore_nodes_.insert(Downcast<Var>(arg)->name_hint());
      } else if (const auto* s_node = arg.as<PrimValueNode>()) {
        ICHECK(input_types[i] != "input") << i << " th PrimValue of " << optype
                                          << " should has special type, get " << input_types;
        attrs.Set(input_types[i], StringUtils::ToString(s_node->value));
      } else if (input_types[i] != "input" && arg->IsInstance<TupleNode>()) {
        attrs.Set(input_types[i], StringUtils::ToString(arg));
      }
    }
    // Extra values from the target function follow the call arguments.
    for (size_t i = call->args.size(); i < input_types.size(); i++) {
      attrs.Set(input_types[i], values[i - call->args.size()]);
    }
  }

  // Build inputs and weights
  Array<String> input_names;
  Map<String, MSCTensor> node_weights;
  if (plugin.defined()) {
    const auto& call = Downcast<Call>(expr);
    if (call->args.size() == 1) {
      ICHECK(expr_tensor_map_.count(call->args[0]))
          << "Can not find tuple plugin input " << call->args[0];
      input_names = expr_tensor_map_[call->args[0]];
    } else {
      const auto& args = GetPluginInputs(expr);
      for (size_t i = 0; i < plugin->inputs.size(); i++) {
        ICHECK(expr_tensor_map_.count(args[i])) << "Can not find plugin input " << args[i];
        for (const auto& in_name : expr_tensor_map_[args[i]]) {
          input_names.push_back(in_name);
        }
      }
    }
  } else if (const auto* call_node = expr.as<CallNode>()) {
    for (size_t i = 0; i < call_node->args.size(); i++) {
      // Arguments already folded into attrs are not graph inputs.
      if (attrs.count(input_types[i])) {
        continue;
      }
      const auto& arg = call_node->args[i];
      Array<String> arg_names;
      if (expr_tensor_map_.count(arg)) {
        arg_names = expr_tensor_map_[arg];
      } else if (input_types[i] == "input" && arg->IsInstance<TupleNode>()) {
        const auto* tuple_node = arg.as<TupleNode>();
        for (const auto& f : tuple_node->fields) {
          ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f;
          for (const auto& in_name : expr_tensor_map_[f]) {
            arg_names.push_back(in_name);
          }
        }
      }
      String weight_name;
      if (input_types[i] != "input" && arg->IsInstance<ConstantNode>()) {
        weight_name = SpanUtils::GetAttr(arg->span, msc_attr::kName);
      } else if (input_types[i] != "input" && func_params_.count(arg) &&
                 func_params_[arg]->IsInstance<ConstantNode>()) {
        weight_name = SpanUtils::GetAttr(func_params_[arg]->span, msc_attr::kName);
        ignore_nodes_.insert(Downcast<Var>(arg)->name_hint());
      }
      // set weights or inputs
      if (weight_name.size() > 0) {
        ICHECK(arg_names.size() > 0) << "Can not find producer of weight " << weight_name;
        const auto& t_name = arg_names[0];
        // Guard the lookup: operator[] would silently insert an empty producer.
        ICHECK(tensor_input_map_.count(t_name)) << "Can not find weight tensor " << t_name;
        const auto& pair = tensor_input_map_[t_name];
        const auto& producer = Downcast<MSCJoint>(pair.first);
        if (!weights_.count(weight_name)) {
          const auto& ref = producer->OutputAt(pair.second);
          MSCTensor weight;
          if (input_types[i] == "bias") {
            weight = MSCTensor(weight_name, ref->dtype, "O", Array<Integer>{ref->GetSize()});
          } else if (input_types[i] == "weight" &&
                     (optype == "msc.linear" || optype == "msc.linear_bias")) {
            if (ref->layout.name() == "IO") {
              // Store the linear weight with its two axes swapped (layout and shape).
              String valid_layout = ref->layout[1].name() + ref->layout[0].name();
              const auto& valid_shape = Array<Integer>({ref->shape[1], ref->shape[0]});
              weight = MSCTensor(weight_name, ref->dtype, valid_layout, valid_shape);
            } else {
              weight = MSCTensor(weight_name, ref->dtype, ref->layout.name(), ref->shape);
            }
          } else {
            weight = MSCTensor(weight_name, ref->dtype, ref->layout.name(), ref->shape);
          }
          weights_.Set(weight_name, weight);
        }
        if (producer->HasAttr("scalar")) {
          attrs.Set(input_types[i], producer->GetTypeAttr<std::string>("scalar"));
        }
        node_weights.Set(input_types[i], weights_[weight_name]);
      } else {
        for (const auto& in_name : arg_names) {
          input_names.push_back(in_name);
        }
      }
    }
  } else if (const auto* tuple_node = expr.as<TupleNode>()) {
    for (const auto& f : tuple_node->fields) {
      ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f;
      for (const auto& in_name : expr_tensor_map_[f]) {
        input_names.push_back(in_name);
      }
    }
  } else if (const auto* getitem_node = expr.as<TupleGetItemNode>()) {
    ICHECK(expr_tensor_map_.count(getitem_node->tuple))
        << "Can not find tuple " << getitem_node->tuple;
    input_names = expr_tensor_map_[getitem_node->tuple];
  } else if (optype == "constant") {
    const auto& t_info = Downcast<TensorStructInfo>(GetStructInfo(expr));
    const auto& shape_opt = t_info->GetShape();
    ICHECK(shape_opt.defined()) << "Constant shape is not defined";
    const auto& weight =
        MSCTensor(node_name, t_info->dtype, layout, ArrayUtils::Cast<Integer>(shape_opt.value()));
    node_weights.Set("const", weight);
  }
  std::vector<std::pair<BaseJoint, size_t>> inputs;
  for (const auto& in_name : input_names) {
    // Guard the lookup: operator[] would silently insert an empty producer.
    ICHECK(tensor_input_map_.count(in_name)) << "Can not find input tensor " << in_name;
    inputs.push_back(tensor_input_map_[in_name]);
  }

  // Redefine layout for special ops
  if (optype == "tuple") {
    layout = "";
    for (size_t i = 0; i < inputs.size(); i++) {
      const auto& in_tensor = Downcast<MSCJoint>(inputs[i].first)->OutputAt(inputs[i].second);
      layout = layout + in_tensor->layout.name();
      layout = layout + (i == inputs.size() - 1 ? "" : ",");
    }
  } else if (optype == "get_item") {
    int idx = std::stoi(attrs["index"]);
    // The producer may expose fewer tensors than its tuple struct info has fields
    // (e.g. nn.batch_norm nodes keep only their first output), so check the range
    // instead of indexing the vector out of bounds.
    ICHECK(idx >= 0 && static_cast<size_t>(idx) < inputs.size())
        << "get_item index " << idx << " out of range, tuple has " << inputs.size()
        << " tensors";
    const auto& in_tensor = Downcast<MSCJoint>(inputs[idx].first)->OutputAt(inputs[idx].second);
    layout = in_tensor->layout.name();
  }

  // Build output tensor; prims records symbolic dims found in prim_map_
  auto build_output = [this](const StructInfo& sinfo, const String& node_name,
                             const String& layout) {
    ICHECK(sinfo->IsInstance<TensorStructInfoNode>())
        << "sinfo should be TensorStructInfo, get " << sinfo->GetTypeKey();
    const auto& t_info = Downcast<TensorStructInfo>(sinfo);
    const auto& shape = ArrayUtils::Cast<Integer>(ExprUtils::GetShape(t_info));
    Array<String> prims;
    bool has_prims = false;
    if (shape.size() > 0) {
      for (const auto& s : t_info->GetShape().value()) {
        if (prim_map_.count(s)) {
          prims.push_back(prim_map_[s]->name);
          has_prims = true;
        } else {
          prims.push_back(StringUtils::ToString(s));
        }
      }
    }
    if (has_prims) {
      return MSCTensor(node_name, t_info->dtype, layout, shape, "", prims);
    }
    return MSCTensor(node_name, t_info->dtype, layout, shape);
  };

  // Gather outputs
  Array<MSCTensor> outputs;
  const auto& sinfo = GetStructInfo(expr);
  Array<String> layouts = StringUtils::Split(layout, ",");
  size_t num_output = 1;
  if (const auto* tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
    num_output = tuple_sinfo->fields.size();
  }
  if (layouts.size() == 0) {
    layouts = Array<String>(num_output, "");
  }
  ICHECK_EQ(layouts.size(), num_output)
      << "Layouts " << layouts << " mismatch with output size " << num_output;
  if (sinfo->IsInstance<TensorStructInfoNode>()) {
    const auto& t_name = node_name + ":" + std::to_string(0);
    outputs.push_back(build_output(sinfo, t_name, layouts[0]));
  } else if (const auto* s_sinfo = sinfo.as<ShapeStructInfoNode>()) {
    // A shape value is represented as a 1-D int32 tensor of length ndim.
    Array<Integer> shape{s_sinfo->ndim};
    const auto& t_name = node_name + ":" + std::to_string(0);
    const auto& dtype = DataType(runtime::StringToDLDataType("int32"));
    outputs.push_back(MSCTensor(t_name, dtype, layouts[0], shape));
  } else if (const auto* tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
    // nn.batch_norm only exposes its first field as a node output.
    size_t field_size = optype == "nn.batch_norm" ? 1 : num_output;
    for (size_t i = 0; i < field_size; i++) {
      const auto& t_name = node_name + ":" + std::to_string(i);
      outputs.push_back(build_output(tuple_sinfo->fields[i], t_name, layouts[i]));
    }
  } else {
    LOG(FATAL) << "Unexpected struct info (" << sinfo->GetTypeKey() << ")" << sinfo;
  }

  // Build node; inputs and constants carry no scope
  Array<String> scope;
  if (optype != "input" && optype != "constant") {
    scope = StringUtils::Split(scope_name_, ".");
  }
  const auto& shared_ref = SpanUtils::GetAttr(expr->span, msc_attr::kSharedRef);
  const auto& node = MSCJoint(nodes_.size(), node_name, shared_ref, optype, attrs, scope, inputs,
                              outputs, node_weights);
  Array<String> output_names;
  for (size_t i = 0; i < outputs.size(); i++) {
    output_names.push_back(outputs[i]->name);
    tensor_input_map_[outputs[i]->name] = std::make_pair(node, i);
  }
  nodes_.push_back(node);
  const auto& ref_expr = binding_var.defined() ? binding_var.value() : expr;
  expr_tensor_map_.Set(ref_expr, output_names);
  return node;
}