PrimFunc MakePackedAPI()

in src/tir/transforms/make_packed_api.cc [195:396]


/*!
 * \brief Rewrite a PrimFunc so that it is callable through the packed (FFI) calling convention.
 *
 * The function is returned unchanged when RequiresPackedAPI() yields no global symbol, or when
 * its kTarget annotation has no host target (treated as "already lowered").  Otherwise:
 *  - the parameter list is replaced by (self_handle, args, num_args, result);
 *  - the body is prefixed with runtime asserts on num_args, the args pointer, and the FFI type
 *    index of every argument, followed by bindings of each original parameter (and of each
 *    buffer in the original buffer_map, via ArgBinder::BindDLTensor) to the unpacked values;
 *  - the original body is passed through ReturnRewriter(result), wrapped in a compute_scope
 *    attribute, and followed by an explicit `return 0`;
 *  - the calling convention becomes kCPackedFunc, kTarget becomes the host target, the
 *    buffer_map is cleared and the return type becomes int32.
 *
 * ICHECK-fails if kTarget is missing, if a parameter has an unsupported dtype, or if the
 * rewritten body still uses variables that are not bound from the packed arguments.
 *
 * \param func The function to transform.
 * \return The transformed function, or `func` unchanged in the early-exit cases above.
 */
PrimFunc MakePackedAPI(PrimFunc func) {
  auto global_symbol = RequiresPackedAPI(func);
  if (!global_symbol.defined()) {
    return func;
  }
  std::string name_hint = global_symbol.value();

  Target target = [&]() {
    auto opt = func->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(opt) << "MakePackedAPI required the function to be annotated with tvm::attr::kTarget ("
                << tvm::attr::kTarget << "), but the function only has attributes " << func->attrs;
    return opt.value();
  }();
  int target_device_type = target->GetTargetDeviceType();

  // A function without a host target has already been lowered.
  Target target_host;
  if (auto opt = target->GetHost()) {
    target_host = opt.value();
  } else {
    return func;
  }

  auto* func_ptr = func.CopyOnWrite();
  const Stmt nop = Evaluate(0);
  const int num_args = static_cast<int>(func_ptr->params.size());

  // Parameters of the packed signature.
  Var v_self_handle("self_handle", DataType::Handle());
  Var v_packed_args("args", DataType::Handle());
  Var v_num_packed_args("num_args", DataType::Int(32));
  Var v_result("result", PointerType(PrimType(DataType::Void())));

  // The device context; device_id is bound by BindDLTensor when a buffer needs it.
  Var device_id("dev_id");
  Integer device_type(target_device_type);
  // seq_init: initialization sequence (argument count / type checks, argument loads).
  // seq_check: checks that run after initialization.
  std::vector<Stmt> seq_init, seq_check, arg_buffer_declarations;
  std::unordered_map<const VarNode*, PrimExpr> vmap;
  ArgBinder binder(&vmap);

  // Load the value payload of the i-th packed argument as `arg_type`.  The struct field is read
  // with the widened API type and cast back down when the two differ.
  auto f_load_arg_value = [&](DataType arg_type, int i) {
    Array<PrimExpr> call_args{v_packed_args, IntImm(DataType::Int(32), i),
                              IntImm(DataType::Int(32), builtin::kTVMFFIAnyUnionValue)};
    DataType api_type = APIType(arg_type);
    PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
    if (api_type != arg_type) {
      res = Cast(arg_type, res);
    }
    return res;
  };

  // Assert the argument count and the type index of each argument.  This must be done
  // *before* any initialization steps produced by `binder.BindDLTensor()`.  The validity of
  // those initialization steps depends on the correct types being present, and they must not
  // occur before the type indices are actually checked.
  seq_init.push_back(MakeAssertEQ(v_num_packed_args, num_args,
                                  name_hint + ": num_args should be " + std::to_string(num_args)));

  if (num_args > 0) {
    seq_init.push_back(MakeAssertNotNull(v_packed_args, name_hint + ": args pointer is NULL"));
  }

  // Binding of the buffers is delayed, in case some arguments also appear in a buffer's
  // shape/strides.
  std::vector<std::pair<PrimExpr, Var>> var_def;
  std::vector<std::pair<Var, Buffer>> buffer_def;

  for (int i = 0; i < num_args; ++i) {
    Var param = func_ptr->params[i];
    PrimExpr arg_value;
    // Read the FFI type index of argument i into a fresh variable.
    Var type_index(param->name_hint + ".type_index", DataType::Int(32));
    seq_init.push_back(LetStmt(type_index,
                               tir::Call(DataType::Int(32), builtin::tvm_struct_get(),
                                         {v_packed_args, IntImm(DataType::Int(32), i),
                                          IntImm(DataType::Int(32), builtin::kTVMFFIAnyTypeIndex)}),
                               nop));
    DataType dtype = param.dtype();
    if (dtype.is_handle()) {
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be pointer";
      seq_init.emplace_back(AssertStmt(type_index == ffi::TypeIndex::kTVMFFINone ||
                                           type_index == ffi::TypeIndex::kTVMFFIOpaquePtr ||
                                           type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr ||
                                           type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin,
                                       tvm::tir::StringImm(msg.str()), nop));
      // If the argument is an NDArray, skip the 16-byte header in front of the embedded
      // DLTensor, so that a T.handle parameter always sees a DLTensor*.
      arg_value = f_load_arg_value(param.dtype(), i);
      PrimExpr handle_from_ndarray =
          Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(),
               {arg_value, IntImm(DataType::Int(32), 16)});
      arg_value =
          Select(type_index == ffi::TypeIndex::kTVMFFINDArray, handle_from_ndarray, arg_value);
    } else if (dtype.is_bool()) {
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be boolean";
      seq_init.emplace_back(AssertStmt(
          type_index == ffi::TypeIndex::kTVMFFIBool || type_index == ffi::TypeIndex::kTVMFFIInt,
          tvm::tir::StringImm(msg.str()), nop));
      arg_value = Cast(DataType::Bool(), f_load_arg_value(DataType::Int(64), i));
    } else if (dtype.is_int() || dtype.is_uint()) {
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be int";
      seq_init.emplace_back(AssertStmt(
          type_index == ffi::TypeIndex::kTVMFFIInt || type_index == ffi::TypeIndex::kTVMFFIBool,
          tvm::tir::StringImm(msg.str()), nop));
      arg_value = f_load_arg_value(param.dtype(), i);
    } else {
      ICHECK(dtype.is_float()) << "MakePackedAPI: " << name_hint << " arg[" << i << "] ("
                               << param->name_hint << ") has unsupported dtype " << dtype;
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be float";
      seq_init.emplace_back(AssertStmt(type_index == ffi::TypeIndex::kTVMFFIFloat ||
                                           type_index == ffi::TypeIndex::kTVMFFIInt ||
                                           type_index == ffi::TypeIndex::kTVMFFIBool,
                                       tvm::tir::StringImm(msg.str()), nop));
      // A float parameter also accepts int/bool arguments: select between the float payload
      // and the int64 payload converted to the parameter's float dtype.
      arg_value = tir::Select(
          type_index == ffi::TypeIndex::kTVMFFIFloat,
          /* true_value = */ f_load_arg_value(param.dtype(), i),
          /* false_value = */ Cast(param.dtype(), f_load_arg_value(DataType::Int(64), i)));
    }
    var_def.emplace_back(arg_value, param);
    if (func_ptr->buffer_map.count(param)) {
      // Buffer binding uses arg_value above, which already accounts for the NDArray offset.
      buffer_def.emplace_back(param, func_ptr->buffer_map[param]);
    }
  }

  // signature: (void* self, TVMFFIAny* packed_args, int num_args, TVMFFIAny* v_result)
  Array<Var> args{v_self_handle, v_packed_args, v_num_packed_args, v_result};

  // Arg definitions are bound before buffer binding to avoid use-before-def errors.
  //
  // For example, for auto broadcasting, checks are required to guarantee that
  // either 0 or the original stride will be correctly used. Checks here have
  // to use the args that may have no let binding yet. Therefore, hoisting let
  // binding for args before buffer declaration is needed.
  for (const auto& [expr, param] : var_def) {
    binder.Bind(param, expr, name_hint + "." + param->name_hint, true);
  }

  for (const auto& [var, buffer] : buffer_def) {
    binder.BindDLTensor(buffer, device_type, device_id, var, name_hint + "." + var->name_hint);
    arg_buffer_declarations.push_back(DeclBuffer(buffer, nop));
  }

  func = WithAttrs(std::move(func),
                   {{tvm::attr::kCallingConv, static_cast<int>(CallingConv::kCPackedFunc)},
                    {tvm::attr::kTarget, target_host}});
  // Re-acquire the mutable node: WithAttrs may hand back a different node than the one the
  // earlier CopyOnWrite() returned, and all remaining edits must land on `func`.
  func_ptr = func.CopyOnWrite();

  Stmt body = ReturnRewriter(v_result)(func_ptr->body);
  body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
                  StringImm(name_hint + "_compute_"), body);
  // Set device context, only when some buffer binding actually defined device_id.
  if (vmap.count(device_id.get())) {
    ObjectRef node = String("default");
    seq_check.push_back(AttrStmt(node, attr::device_id, device_id, nop));
    seq_check.push_back(AttrStmt(node, attr::device_type, device_type, nop));

    if (runtime::DeviceAPI::NeedSetDevice(target_device_type)) {
      Stmt set_device =
          Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
                        {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}));
      body = SeqStmt({set_device, body});
    }
  }

  // Return error code of zero on success
  body = SeqStmt({body, Evaluate(ret(Integer(0)))});

  body = MergeNest(
      {seq_init, binder.init_nest(), seq_check, binder.asserts(), arg_buffer_declarations}, body);
  func_ptr->body = body;
  func_ptr->params = args;

  Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
  ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << name_hint << " variables " << undefined
                                 << " are used, but are not passed in as API arguments";

  func_ptr->buffer_map = Map<Var, Buffer>();
  // Set ret_type before caching checked_type_: func_type_annotation() is derived from the
  // params and ret_type, so computing it first would record the stale return type.
  func_ptr->ret_type = PrimType(DataType::Int(32));
  func_ptr->checked_type_ = func_ptr->func_type_annotation();

  return func;
}