class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)>

in src/relax/backend/vm/codegen_vm_tir.cc [52:519]
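
CodeGenVMTIR lowers each Relax function into a TIR PrimFunc named `<system_lib_prefix>__vmtir__<global_symbol>`. The generated body drives the Relax VM through packed-function calls, accessing virtual registers, constants, and callable functions via AnyList handles that are passed in as parameters alongside the VM context pointer.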


class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)> {
 public:
  explicit CodeGenVMTIR(relax::ExecBuilder builder, IRModule ctx_mod)
      : builder_(builder), ctx_mod_(ctx_mod) {
    system_lib_prefix_ = ctx_mod_->GetAttr<String>(tvm::attr::kSystemLibPrefix);
  }

  static IRModule Run(relax::ExecBuilder builder, IRModule mod) {
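    // Entry point: lower every Relax function in `mod` into a VM TIR PrimFunc,
    // add the lowered function to a copy of the module, and remove the original Relax function.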
    // create a new copy
    IRModule res_mod = mod;
    res_mod.CopyOnWrite();

    CodeGenVMTIR codegen(builder, mod);
    // Remove relax function and turn into TIR func.
    for (auto& p : mod->functions) {
      if (auto* func = p.second.as<FunctionNode>()) {
        auto tir_func = codegen.Codegen(GetRef<Function>(func));
        auto gsymbol = tir_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
        res_mod->Add(GlobalVar(gsymbol.value()), tir_func);
        res_mod->Remove(p.first);
      }
    }
    return res_mod;
  }

 private:
  int64_t NewRegister() { return registers_num_++; }

  static IntImm ConstInt64(int64_t value) { return IntImm(DataType::Int(64), value); }

  static IntImm ConstInt32(int64_t value) { return IntImm(DataType::Int(32), value); }

  PrimExpr RegListGet(int64_t slot) const {
    // Read register `slot` from the register AnyList; each item is a 128-bit Any, returned as a handle.
    return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(),
                     {reg_anylist_handle_, ConstInt32(slot)});
  }

  PrimExpr ConstListGet(int64_t slot) const {
    // Read constant `slot` from the constant AnyList; each item is a 128-bit Any, returned as a handle.
    return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(),
                     {const_anylist_handle_, ConstInt32(slot)});
  }

  PrimExpr FuncListGet(int64_t slot) const {
    // Read function `slot` from the function AnyList; each item is a 128-bit Any, returned as a handle.
    return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(),
                     {func_anylist_handle_, ConstInt32(slot)});
  }

  void EmitStmt(tir::Stmt stmt) {
    ICHECK(!stmt_stack_.empty());
    stmt_stack_.back().emplace_back(stmt);
  }

  void EmitCallPacked(String name, const Array<PrimExpr>& args, int64_t dst_anylist_slot = -1) {
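    // Emit a call to the packed function `name` with `args`; when dst_anylist_slot >= 0,
    // the result is written into that slot of the register AnyList.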
    Array<PrimExpr> all_args;
    // A negative dst_anylist_slot indicates the return value can be discarded; emit a plain tvm_call_packed.
    if (dst_anylist_slot >= 0) {
      all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)};
    }
    all_args.push_back(tir::StringImm(name));
    for (PrimExpr arg : args) {
      all_args.push_back(arg);
    }
    if (dst_anylist_slot >= 0) {
      this->EmitStmt(tir::Evaluate(
          tir::Call(DataType::Int(32), tir::builtin::anylist_setitem_call_packed(), all_args)));
    } else {
      this->EmitStmt(
          tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), all_args)));
    }
  }

  void EmitCallCPacked(const tir::PrimFunc& prim_func, const Array<PrimExpr>& args,
                       int64_t dst_anylist_slot = -1) {
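    // Emit a direct cpacked call to `prim_func` (addressed by its global symbol);
    // when dst_anylist_slot >= 0, the result is written into that slot of the register AnyList.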
    Optional<String> gsymbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
    ICHECK(gsymbol.defined()) << "All functions must have global symbol at this phase";
    Array<PrimExpr> all_args;
    // A negative dst_anylist_slot indicates the return value can be discarded; emit a plain tvm_call_cpacked.
    if (dst_anylist_slot >= 0) {
      all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)};
    }
    all_args.push_back(tir::StringImm(gsymbol.value()));
    for (PrimExpr arg : args) {
      all_args.push_back(arg);
    }
    if (dst_anylist_slot >= 0) {
      this->EmitStmt(tir::Evaluate(
          tir::Call(DataType::Int(32), tir::builtin::anylist_setitem_call_cpacked(), all_args)));
    } else {
      this->EmitStmt(
          tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_call_cpacked(), all_args)));
    }
  }

  tir::PrimFunc Codegen(const Function& func) {
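    // Lower a single Relax function: parameters map to register slots 0..n-1, one extra
    // register receives the (copied) return value, and the result is a PrimFunc taking
    // (ctx, register AnyList, constant AnyList, function AnyList) as parameters.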
    Optional<String> gsymbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
    ICHECK(gsymbol.defined()) << "There should be no local functions in the Relax VM codegen phase. "
                                 "Did you forget to apply the LambdaLift or AttachGlobalSymbol pass?";
    // initialize the state
    stmt_stack_ = {};
    registers_num_ = 0;
    var_map_.clear();
    ctx_ptr_ = tir::Var("ctx_ptr", DataType::Handle());
    reg_anylist_handle_ = tir::Var("r", DataType::Handle());
    func_anylist_handle_ = tir::Var("f", DataType::Handle());
    const_anylist_handle_ = tir::Var("c", DataType::Handle());

    Array<String> param_names;
    for (Var param : func->params) {
      param_names.push_back(param->name_hint());
    }
    // declare this function.
    builder_->DeclareFunction(gsymbol.value(), vm::VMFuncInfo::FuncKind::kVMTIRFunc);

    for (size_t i = 0; i < func->params.size(); ++i) {
      int64_t r = NewRegister();
      ICHECK_EQ(static_cast<size_t>(r), i);
      this->var_map_.insert({func->params[i], RegListGet(r)});
    }
    size_t ret_reg = NewRegister();

    tir::Stmt body = WithNewScope([&]() {
      Optional<PrimExpr> ret = ExprFunctor::VisitExpr(func->body);
      if (ret.defined()) {
        this->EmitCallPacked("vm.builtin.copy", {ret.value()}, ret_reg);
      }
    });

    // Mark the function entry internally.
    builder_->EmitFunction(gsymbol.value(), param_names.size(), param_names,
                           VMFuncInfo::FuncKind::kVMTIRFunc, registers_num_);
    builder_->EndFunction(gsymbol.value());

    Type ret_type = VoidType();
    Array<tir::Var> tir_params = {ctx_ptr_, reg_anylist_handle_, const_anylist_handle_,
                                  func_anylist_handle_};
    String tir_func_name = system_lib_prefix_.value_or("") + "__vmtir__" + gsymbol.value();
    tir::PrimFunc tir_func(tir_params, body, ret_type, {});
    tir_func = WithAttr(tir_func, "global_symbol", tir_func_name);
    registers_num_ = 0;
    var_map_.clear();
    stmt_stack_.clear();
    return tir_func;
  }

  Optional<PrimExpr> VisitExpr_(const SeqExprNode* op) final {
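    // Lower each binding in order, recording the resulting PrimExpr so later Var
    // references can be resolved, then lower the SeqExpr body as the overall value.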
    for (auto block : op->blocks) {
      for (Binding binding : block->bindings) {
        Expr expr = GetBoundValue(binding);
        Optional<PrimExpr> value = VisitExpr(expr);

        if (expr.as<Var>() && value.defined()) {
          // For a normalized relax module, there should be one
          // register for each relax::Binding.  This makes the Relax
          // semantics of R.vm.kill_* operate the same as the Python
          // "del" operator.  These bindings may be removable by using
          // relax.transform.CanonicalizeBindings earlier in lowering.
          auto new_reg = NewRegister();
          EmitCallPacked("vm.builtin.copy", {value.value()}, new_reg);
          value = RegListGet(new_reg);
        }

        this->var_map_.insert({binding->var, value});
      }
    }
    return this->VisitExpr(op->body);
  }

  Optional<PrimExpr> VisitExpr_(const CallNode* call_node) final {
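    // Dispatch on the callee: relax.null_value becomes a null handle, the VM-specific ops
    // use dedicated emitters below, and every other callee is lowered as a packed/closure call.
    // Calls with void StructInfo use dst_reg = -1 so no result register is allocated.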
    Call call = GetRef<Call>(call_node);

    if (call_node->op == null_value_op_) {
      return tir::Call(DataType::Handle(), tir::builtin::reinterpret(),
                       {IntImm(DataType::Int(64), 0)});
    }
    int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister();
    if (call->op.as<OpNode>()) {
      if (call_node->op == call_builtin_with_ctx_op_) {
        EmitCallBuiltinWithCtx(call, dst_reg);
      } else if (call_node->op == alloc_storage_op_) {
        EmitAllocStorage(call, dst_reg);
      } else if (call_node->op == alloc_tensor_op_) {
        EmitAllocTensor(call, dst_reg);
      } else if (call_node->op == kill_object_op_) {
        dst_reg = EmitKillObject(call);
      } else {
        // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those
        // ops are handled in a pass when lowering them to TIR.
        LOG(FATAL) << "CodeGenVMTIR cannot handle this intrinsic now:\n" << call_node->op;
      }
    } else {
      EmitNormalCall(call, dst_reg);
    }
    if (dst_reg >= 0) {
      return RegListGet(dst_reg);
    } else {
      return NullOpt;
    }
  }

  Optional<PrimExpr> VisitExpr_(const IfNode* op) final {
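    // Lower both branches in fresh scopes; each branch copies its value into a shared
    // merge register, which becomes the value of the If expression. The condition is
    // materialized through vm.builtin.read_if_cond.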
    // Reserve a register for return
    size_t merge_register = NewRegister();
    PrimExpr cond_value = this->VisitExpr(op->cond).value();

    cond_value = tir::Call(DataType::Bool(), tir::builtin::tvm_call_packed(),
                           {tir::StringImm("vm.builtin.read_if_cond"), cond_value});

    tir::Stmt true_branch = WithNewScope([&]() {
      PrimExpr true_value = this->VisitExpr(op->true_branch).value();
      this->EmitCallPacked("vm.builtin.copy", {true_value}, merge_register);
    });
    tir::Stmt false_branch = WithNewScope([&]() {
      PrimExpr false_value = this->VisitExpr(op->false_branch).value();
      this->EmitCallPacked("vm.builtin.copy", {false_value}, merge_register);
    });
    this->EmitStmt(tir::IfThenElse(cond_value, true_branch, false_branch));
    return RegListGet(merge_register);
  }

  Optional<PrimExpr> VisitExpr_(const VarNode* op) final {
    Var var = GetRef<Var>(op);
    auto it = this->var_map_.find(var);
    ICHECK(it != this->var_map_.end()) << "Var " << var << " is not defined";
    return it->second;
  }

  Optional<PrimExpr> VisitExpr_(const ConstantNode* op) final {
    return ConstListGet(builder_->ConvertConstant(op->data).value());
  }

  Optional<PrimExpr> VisitExpr_(const ShapeExprNode* op) final {
    std::vector<int64_t> shape;
    for (PrimExpr e : op->values) {
      if (auto* int_value = e.as<IntImmNode>()) {
        shape.push_back(int_value->value);
      } else {
        LOG(FATAL) << "Should only use constant shape after shape lowering: " << op->values;
      }
    }
    return ConstListGet(builder_->ConvertConstant(ShapeTuple(shape)).value());
  }

  Optional<PrimExpr> VisitExpr_(const PrimValueNode* op) final { return op->value; }

  Optional<PrimExpr> VisitExpr_(const StringImmNode* op) final {
    return ConstListGet(builder_->ConvertConstant(op->value).value());
  }

  Optional<PrimExpr> VisitExpr_(const DataTypeImmNode* op) final {
    return ConstListGet(builder_->ConvertConstant(op->value).value());
  }

  Optional<PrimExpr> VisitExpr_(const TupleNode* op) final {
    Tuple tuple = GetRef<Tuple>(op);
    Array<PrimExpr> args;
    for (auto arg : tuple->fields) {
      args.push_back(this->VisitExpr(arg).value());
    }
    int64_t dst_register = NewRegister();
    this->EmitCallPacked("vm.builtin.make_tuple", args, dst_register);
    return RegListGet(dst_register);
  }

  Optional<PrimExpr> VisitExpr_(const TupleGetItemNode* op) final {
    TupleGetItem expr = GetRef<TupleGetItem>(op);
    Array<PrimExpr> args = {this->VisitExpr(expr->tuple).value()};

    args.push_back(ConstInt64(expr->index));

    int64_t dst_register = NewRegister();
    this->EmitCallPacked("vm.builtin.tuple_getitem", args, dst_register);
    return RegListGet(dst_register);
  }

  // Look up the function referenced by `expr` and report its FuncKind.
  Optional<String> LookupFunction(const Expr& expr, VMFuncInfo::FuncKind* kind) {
    if (auto* ext_func = expr.as<ExternFuncNode>()) {
      *kind = VMFuncInfo::FuncKind::kPackedFunc;
      return ext_func->global_symbol;
    } else if (auto* gvar_ptr = expr.as<GlobalVarNode>()) {
      GlobalVar gvar = GetRef<GlobalVar>(gvar_ptr);
      // Run a look up in the env to see if it maps to an extern func.
      auto it = ctx_mod_->functions.find(gvar);
      if (it != ctx_mod_->functions.end()) {
        BaseFunc func = (*it).second;
        if (auto* efunc = func.as<ExternFuncNode>()) {
          *kind = VMFuncInfo::FuncKind::kPackedFunc;
          return efunc->global_symbol;
        } else if (func.as<FunctionNode>()) {
          *kind = VMFuncInfo::FuncKind::kVMTIRFunc;
          return gvar->name_hint;
        } else if (func.as<tir::PrimFuncNode>()) {
          *kind = VMFuncInfo::FuncKind::kPackedFunc;
          return gvar->name_hint;
        } else {
          *kind = VMFuncInfo::FuncKind::kPackedFunc;
          return gvar->name_hint;
        }
      }
      LOG(WARNING) << "Undefined global var " << gvar->name_hint;
      // Undefined global var; consider eliminating this case later.
      *kind = VMFuncInfo::FuncKind::kPackedFunc;
      return gvar->name_hint;
    } else {
      return NullOpt;
    }
  }
  // Look up a PrimFunc defined in the same module.
  // In that case we can emit a direct cpacked call to it.
  Optional<tir::PrimFunc> LookupPrimFunc(const String& name) {
    if (!ctx_mod_->ContainGlobalVar(name)) return NullOpt;

    GlobalVar gvar = ctx_mod_->GetGlobalVar(name);
    auto it = ctx_mod_->functions.find(gvar);
    if (it != ctx_mod_->functions.end()) {
      BaseFunc func = (*it).second;
      if (auto* prim_func = func.as<tir::PrimFuncNode>()) {
        return GetRef<tir::PrimFunc>(prim_func);
      }
    }
    return NullOpt;
  }

  Optional<PrimExpr> VisitExpr_(const GlobalVarNode* op) final {
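    // A GlobalVar used as a value: declare the referenced function in the executable
    // and load it from the function AnyList at the slot reported by the builder.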
    VMFuncInfo::FuncKind kind;
    auto symbol = LookupFunction(GetRef<Expr>(op), &kind);
    ICHECK(symbol.defined());
    builder_->DeclareFunction(symbol.value(), kind);
    return FuncListGet(builder_->GetFunction(symbol.value()).value());
  }

  Optional<PrimExpr> VisitExpr_(const ExternFuncNode* op) final {
    builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc);
    return FuncListGet(builder_->GetFunction(op->global_symbol).value());
  }

  void EmitAllocStorage(const Call& call_node, int64_t dst_reg) {
    // Handle args of the call
    Array<PrimExpr> args;
    args.push_back(ctx_ptr_);
    for (Expr arg : call_node->args) {
      args.push_back(this->VisitExpr(arg).value());
    }
    this->EmitCallPacked("vm.builtin.alloc_storage", args, dst_reg);
  }

  void EmitAllocTensor(const Call& call_node, int64_t dst_reg) {
    ICHECK_EQ(call_node->args.size(), 4);
    Array<PrimExpr> args;
    args.reserve(4);
    for (Expr arg : call_node->args) {
      args.push_back(this->VisitExpr(arg).value());
    }
    this->EmitCallPacked("vm.builtin.alloc_tensor", args, dst_reg);
  }

  int64_t EmitKillObject(const Call& call_node) {
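    // R.vm.kill_object(x): recover the register slot holding `x` (its lowered form must be
    // an anylist_getitem on the register list) and overwrite that slot with a null value.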
    ICHECK_EQ(call_node->args.size(), 1);
    PrimExpr arg = this->VisitExpr(call_node->args[0]).value();

    // Check the arg is a register.
    const auto* tir_call = arg.as<tir::CallNode>();
    ICHECK(tir_call != nullptr);
    ICHECK(tir_call->op == tir::builtin::anylist_getitem());
    ICHECK(tir_call->args.size() == 2);
    ICHECK(tir_call->args[0].same_as(reg_anylist_handle_));
    const auto* p_dst_reg = tir_call->args[1].as<tir::IntImmNode>();
    ICHECK(p_dst_reg != nullptr);
    ICHECK(p_dst_reg->dtype == DataType::Int(32));

    int64_t dst_reg = p_dst_reg->value;
    this->EmitCallPacked("vm.builtin.null_value", {}, dst_reg);
    return dst_reg;
  }

  void EmitCallBuiltinWithCtx(const Call& call_node, int64_t dst_reg) {
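    // relax.call_builtin_with_ctx(extern_func, (args...)): call the named builtin with the
    // VM context prepended, followed by the unpacked tuple arguments.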
    Array<PrimExpr> args;
    // if context is required, pass as first argument.
    args.push_back(ctx_ptr_);
    auto* func = call_node->args[0].as<ExternFuncNode>();
    ICHECK(func) << "call_builtin_with_ctx expects an ExternFunc as its first argument";

    auto tuple_arg = Downcast<Tuple>(call_node->args[1]);

    // Handle args of the call
    for (Expr arg : tuple_arg->fields) {
      args.push_back(this->VisitExpr(arg).value());
    }

    this->EmitCallPacked(func->global_symbol, args, dst_reg);
  }

  void EmitNormalCall(const Call& call_node, int64_t dst_reg) {
    Array<PrimExpr> args = VisitArray(call_node->args);
    // The callee may be a closure that comes from a parent scope,
    // so invoke it as a closure when it cannot be resolved to a packed function.
    VMFuncInfo::FuncKind kind;
    auto symbol = LookupFunction(call_node->op, &kind);

    if (symbol.defined() && kind == VMFuncInfo::FuncKind::kPackedFunc) {
      // PrimFunc in the same module:
      // use cpacked to invoke it directly, without a name-based lookup.
      if (Optional<tir::PrimFunc> prim_func = LookupPrimFunc(symbol.value())) {
        this->EmitCallCPacked(prim_func.value(), args, dst_reg);
      } else {
        this->EmitCallPacked(symbol.value(), args, dst_reg);
      }
    } else {
      // Default path, leverage function table and invoke as closure
      Array<PrimExpr> all_args;
      all_args.push_back(ctx_ptr_);
      all_args.push_back(this->VisitExpr(call_node->op).value());
      for (auto arg : args) {
        all_args.push_back(arg);
      }
      this->EmitCallPacked("vm.builtin.invoke_closure", all_args, dst_reg);
    }
  }

  template <typename FLambda>
  tir::Stmt WithNewScope(const FLambda& callback) {
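    // Collect every statement emitted by `callback` into a fresh scope and
    // return them flattened into a single SeqStmt.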
    stmt_stack_.push_back({});
    callback();
    tir::Stmt stmt = tir::SeqStmt::Flatten(stmt_stack_.back());
    stmt_stack_.pop_back();
    return stmt;
  }

  Array<PrimExpr> VisitArray(const Array<Expr>& arr) {
    Array<PrimExpr> ret;
    for (size_t i = 0; i < arr.size(); ++i) {
      ret.push_back(this->VisitExpr(arr[i]).value());
    }
    return ret;
  }
  /*! \brief Internal ExecBuilder. */
  relax::ExecBuilder builder_;
  /*! \brief The VM context pointer parameter. */
  tir::Var ctx_ptr_;
  /*! \brief AnyList handle that stores temporary objects (the virtual registers). */
  tir::Var reg_anylist_handle_;
  /*! \brief AnyList handle that stores functions/closures. */
  tir::Var func_anylist_handle_;
  /*! \brief AnyList handle that stores constants. */
  tir::Var const_anylist_handle_;
  /*!
   * \brief Total number of virtual registers (AnyList slots) allocated.
   * \note Parameter slots come first, followed by the return-value slot and temporaries.
   */
  int64_t registers_num_ = 0;
  /*! \brief Stack to build up statements */
  std::vector<std::vector<tir::Stmt>> stmt_stack_;
  /*! \brief Map from var to Expr. */
  std::unordered_map<Var, Optional<PrimExpr>> var_map_;
  /*! \brief the context module. */
  IRModule ctx_mod_;
  /*! \brief system lib prefix */
  Optional<String> system_lib_prefix_;
  /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */
  const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
  const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
  const Op& kill_object_op_ = Op::Get("relax.vm.kill_object");
  const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
  const Op& null_value_op_ = Op::Get("relax.null_value");
};