in src/dlr_relayvm.cc [27:103]
void RelayVMModel::SetupVMModule(const std::vector<DLRModelElem>& model_elems) {
// Set custom allocators in TVM.
if (dlr::DLRAllocatorFunctions::GetMemalignFunction() &&
dlr::DLRAllocatorFunctions::GetFreeFunction()) {
auto* pf = tvm::runtime::Registry::Get("runtime.contrib.set_custom_cpu_allocator");
if (pf) {
(*pf)(reinterpret_cast<void*>(dlr::DLRAllocatorFunctions::GetMemalignFunction()),
reinterpret_cast<void*>(dlr::DLRAllocatorFunctions::GetFreeFunction()));
} else {
LOG(WARNING) << "Custom allocator functions are not available. Using default allocators.";
}
} else if (dlr::DLRAllocatorFunctions::AnySet()) {
LOG(WARNING) << "SetDLRCustomAllocatorFree() and SetDLRCustomAllocatorMemalign() must be set "
"to override TVM allocations. Using default allocators.";
}
std::string code_data;
std::string model_lib_path;
std::string metadata_data;
for (DLRModelElem el : model_elems) {
if (el.type == DLRModelElemType::RELAY_EXEC) {
if (el.path != nullptr) {
code_data = dlr::LoadFileToString(el.path, std::ios::binary);
} else if (el.data != nullptr && el.data_size > 0) {
code_data.assign(static_cast<const char*>(el.data), el.data_size);
} else {
throw dmlc::Error("Invalid RelayVM model element RELAY_EXEC");
}
} else if (el.type == DLRModelElemType::TVM_LIB) {
if (el.path != nullptr) {
model_lib_path = el.path;
} else {
throw dmlc::Error("Invalid RelayVM model element TVM_LIB. TVM_LIB must be a file path.");
}
} else if (el.type == DLRModelElemType::NEO_METADATA) {
if (el.path != nullptr) {
metadata_data = dlr::LoadFileToString(el.path);
} else if (el.data != nullptr) {
metadata_data = static_cast<const char*>(el.data);
} else {
throw dmlc::Error("Invalid model element NEO_METADATA");
}
}
}
if (code_data.empty() || model_lib_path.empty() || metadata_data.empty()) {
throw dmlc::Error(
"Invalid RelayVM model. Must have RELAY_EXEC, TVM_LIB and NEO_METADATA elements");
}
LoadJsonFromString(metadata_data, this->metadata_);
ValidateDeviceTypeIfExists();
// Override allocator - default is kPooled.
const char* val = std::getenv("DLR_RELAYVM_ALLOCATOR");
if ((metadata_.count("Model") && metadata_["Model"].count("RelayVMAllocator") &&
metadata_["Model"]["RelayVMAllocator"].get<std::string>() == "naive") ||
(val != nullptr && std::string(val) == "naive")) {
allocator_type_ = tvm::runtime::vm::AllocatorType::kNaive;
}
tvm::runtime::Module lib = tvm::runtime::Module::LoadFromFile(model_lib_path);
vm_executable_ =
std::make_shared<tvm::runtime::Module>(tvm::runtime::vm::Executable::Load(code_data, lib));
auto vm = tvm::runtime::make_object<tvm::runtime::vm::VirtualMachine>();
vm->LoadExecutable(static_cast<tvm::runtime::vm::Executable*>(
const_cast<tvm::runtime::Object*>(vm_executable_->get())));
vm_module_ = std::make_shared<tvm::runtime::Module>(tvm::runtime::Module(vm));
tvm::runtime::PackedFunc init = vm_module_->GetFunction("init");
if (dev_.device_type == DLDeviceType::kDLCPU) {
init(static_cast<int>(dev_.device_type), dev_.device_id, static_cast<int>(allocator_type_));
} else {
// CPU context also must be initialized because input/output data comes from CPU.
init(static_cast<int>(dev_.device_type), dev_.device_id, static_cast<int>(allocator_type_),
static_cast<int>(DLDeviceType::kDLCPU), 0, static_cast<int>(allocator_type_));
}
}