in src/tir/transforms/make_packed_api.cc [195:396]
/*!
 * \brief Lower a PrimFunc that requires the packed calling convention into its
 *        packed-API form.
 *
 * The lowered function takes `(self_handle, args, num_args, result)`, i.e.
 * `(void* self, TVMFFIAny* packed_args, int num_args, TVMFFIAny* result)`, and
 * its return type becomes int32. The generated body:
 *   1. asserts `num_args` matches the original parameter count and that `args`
 *      is non-null (when there are parameters);
 *   2. reads the type index of every packed argument and asserts it is
 *      compatible with the parameter dtype;
 *   3. unpacks each argument into the original parameter variable, then binds
 *      any buffers from `buffer_map` to the incoming DLTensor handles;
 *   4. runs the original body (passed through `ReturnRewriter` with `result`),
 *      followed by `return 0`.
 *
 * \param func The function to lower.
 * \return The lowered function. `func` is returned unchanged when
 *         `RequiresPackedAPI` yields no global symbol, or when its target has
 *         no host target.
 */
PrimFunc MakePackedAPI(PrimFunc func) {
  auto global_symbol = RequiresPackedAPI(func);
  if (!global_symbol.defined()) {
    return func;
  }
  std::string name_hint = global_symbol.value();

  Target target = [&]() {
    auto opt = func->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(opt) << "MakePackedAPI required the function to be annotated with tvm::attr::kTarget ("
                << tvm::attr::kTarget << "), but the function only has attributes " << func->attrs;
    return opt.value();
  }();
  int target_device_type = target->GetTargetDeviceType();

  // A function without a host target has already been lowered.
  Target target_host;
  if (auto opt = target->GetHost()) {
    target_host = opt.value();
  } else {
    return func;
  }

  auto* func_ptr = func.CopyOnWrite();
  const Stmt nop = Evaluate(0);
  int num_args = static_cast<int>(func_ptr->params.size());

  // Byte offset from the start of an NDArray object to its embedded DLTensor,
  // i.e. the size of the FFI object header that precedes the DLTensor.
  // NOTE(review): this must stay in sync with sizeof(TVMFFIObject) of the
  // tvm-ffi version in use — confirm whenever the object header layout changes.
  constexpr int kNDArrayDLTensorByteOffset = 16;

  // The packed-API parameters.
  Var v_self_handle("self_handle", DataType::Handle());
  Var v_packed_args("args", DataType::Handle());
  Var v_num_packed_args("num_args", DataType::Int(32));
  Var v_result("result", PointerType(PrimType(DataType::Void())));

  // The device context.
  Var device_id("dev_id");
  Integer device_type(target_device_type);

  // seq_init: initialization sequence (argument checks and type-index lets).
  // seq_check: checks emitted after initialization.
  std::vector<Stmt> seq_init, seq_check, arg_buffer_declarations;
  std::unordered_map<const VarNode*, PrimExpr> vmap;
  ArgBinder binder(&vmap);

  // Load the value field of the i-th packed argument as type `arg_type`.
  // The value is read using the API type of `arg_type` and cast back when the
  // two differ.
  auto f_load_arg_value = [&](DataType arg_type, int i) {
    Array<PrimExpr> call_args{v_packed_args, IntImm(DataType::Int(32), i),
                              IntImm(DataType::Int(32), builtin::kTVMFFIAnyUnionValue)};
    DataType api_type = APIType(arg_type);
    PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
    if (api_type != arg_type) {
      res = Cast(arg_type, res);
    }
    return res;
  };

  // Assert correct type codes for each argument. This must be done
  // *before* any initialization steps produced by
  // `binder.BindDLTensor()`. The validity of those initialization
  // steps depends on the correct types being present, and must not
  // occur before the type codes are actually checked.
  seq_init.push_back(MakeAssertEQ(v_num_packed_args, num_args, [&]() -> std::string {
    std::ostringstream error_message;
    error_message << name_hint << ": num_args should be " << num_args;
    return error_message.str();
  }()));
  if (num_args > 0) {
    seq_init.push_back(MakeAssertNotNull(v_packed_args, name_hint + ": args pointer is NULL"));
  }

  // Binding of the buffers is delayed in case some arguments also appear in
  // the buffer definitions.
  std::vector<std::pair<PrimExpr, Var>> var_def;
  std::vector<std::pair<Var, Buffer>> buffer_def;

  for (int i = 0; i < num_args; ++i) {
    Var param = func_ptr->params[i];
    PrimExpr arg_value;

    // Read the type index of the i-th packed argument.
    Var type_index(param->name_hint + ".type_index", DataType::Int(32));
    seq_init.push_back(LetStmt(type_index,
                               tir::Call(DataType::Int(32), builtin::tvm_struct_get(),
                                         {v_packed_args, IntImm(DataType::Int(32), i),
                                          IntImm(DataType::Int(32), builtin::kTVMFFIAnyTypeIndex)}),
                               nop));

    DataType dtype = param.dtype();
    if (dtype.is_handle()) {
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be pointer";
      seq_init.emplace_back(AssertStmt(type_index == ffi::TypeIndex::kTVMFFINone ||
                                           type_index == ffi::TypeIndex::kTVMFFIOpaquePtr ||
                                           type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr ||
                                           type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin,
                                       tvm::tir::StringImm(msg.str()), nop));
      // If the argument is an NDArray, skip the object header so that a
      // T.handle parameter always refers to a DLTensor*.
      arg_value = f_load_arg_value(param.dtype(), i);
      PrimExpr handle_from_ndarray =
          Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(),
               {arg_value, IntImm(DataType::Int(32), kNDArrayDLTensorByteOffset)});
      arg_value =
          Select(type_index == ffi::TypeIndex::kTVMFFINDArray, handle_from_ndarray, arg_value);
    } else if (dtype.is_bool()) {
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be boolean";
      seq_init.emplace_back(AssertStmt(
          type_index == ffi::TypeIndex::kTVMFFIBool || type_index == ffi::TypeIndex::kTVMFFIInt,
          tvm::tir::StringImm(msg.str()), nop));
      arg_value = Cast(DataType::Bool(), f_load_arg_value(DataType::Int(64), i));
    } else if (dtype.is_int() || dtype.is_uint()) {
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be int";
      seq_init.emplace_back(AssertStmt(
          type_index == ffi::TypeIndex::kTVMFFIInt || type_index == ffi::TypeIndex::kTVMFFIBool,
          tvm::tir::StringImm(msg.str()), nop));
      arg_value = f_load_arg_value(param.dtype(), i);
    } else {
      ICHECK(dtype.is_float()) << name_hint << ": unsupported dtype " << dtype << " for arg[" << i
                               << "] (" << param->name_hint << ") in packed API";
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be float";
      seq_init.emplace_back(AssertStmt(type_index == ffi::TypeIndex::kTVMFFIFloat ||
                                           type_index == ffi::TypeIndex::kTVMFFIInt ||
                                           type_index == ffi::TypeIndex::kTVMFFIBool,
                                       tvm::tir::StringImm(msg.str()), nop));
      // Use a select so that int/bool arguments are also accepted: they are
      // loaded as int64 and converted to the float parameter type.
      arg_value = tir::Select(
          type_index == ffi::TypeIndex::kTVMFFIFloat,
          /* true_value = */ f_load_arg_value(param.dtype(), i),
          /* false_value = */ Cast(param.dtype(), f_load_arg_value(DataType::Int(64), i)));
    }
    var_def.emplace_back(arg_value, param);
    if (func_ptr->buffer_map.count(param)) {
      // The buffer is bound to `param`, whose value (above) already accounts
      // for the NDArray header offset when the argument is an NDArray.
      buffer_def.emplace_back(param, func_ptr->buffer_map[param]);
    }
  }

  // signature: (void* self, TVMFFIAny* packed_args, int num_args, TVMFFIAny* v_result)
  Array<Var> args{v_self_handle, v_packed_args, v_num_packed_args, v_result};

  // Arg definitions are defined before buffer binding to avoid the use before
  // def errors.
  //
  // For example, for auto broadcasting, checks are required to guarantee that
  // either 0 or the original stride will be correctly used. Checks here have
  // to use the args that may have no let binding yet. Therefore, hoisting let
  // binding for args before buffer declaration is needed.
  for (const auto& [expr, param] : var_def) {
    binder.Bind(param, expr, name_hint + "." + param->name_hint, true);
  }

  for (const auto& [var, buffer] : buffer_def) {
    binder.BindDLTensor(buffer, device_type, device_id, var, name_hint + "." + var->name_hint);
    arg_buffer_declarations.push_back(DeclBuffer(buffer, nop));
  }

  func = WithAttrs(std::move(func),
                   {{tvm::attr::kCallingConv, static_cast<int>(CallingConv::kCPackedFunc)},
                    {tvm::attr::kTarget, target_host}});
  // `func` was reassigned above; re-acquire the mutable node instead of relying
  // on the earlier pointer remaining valid across the reassignment.
  func_ptr = func.CopyOnWrite();

  Stmt body = ReturnRewriter(v_result)(func_ptr->body);
  body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
                  StringImm(name_hint + "_compute_"), body);

  // Set device context
  if (vmap.count(device_id.get())) {
    ObjectRef node = String("default");
    seq_check.push_back(AttrStmt(node, attr::device_id, device_id, nop));
    seq_check.push_back(AttrStmt(node, attr::device_type, device_type, nop));

    if (runtime::DeviceAPI::NeedSetDevice(target_device_type)) {
      Stmt set_device =
          Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
                        {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}));
      body = SeqStmt({set_device, body});
    }
  }

  // Return error code of zero on success
  body = SeqStmt({body, Evaluate(ret(Integer(0)))});

  body = MergeNest(
      {seq_init, binder.init_nest(), seq_check, binder.asserts(), arg_buffer_declarations}, body);
  func_ptr->body = body;
  func_ptr->params = args;

  Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
  ICHECK(undefined.empty()) << "In PrimFunc " << name_hint << " variables " << undefined
                            << " are used, but are not passed in as API arguments";

  func_ptr->buffer_map = Map<Var, Buffer>();
  // Update the return type first: func_type_annotation() is derived from the
  // params and ret_type, so the cached type must be computed afterwards.
  func_ptr->ret_type = PrimType(DataType::Int(32));
  func_ptr->checked_type_ = func_ptr->func_type_annotation();

  return func;
}