void RelayVMModel::SetupVMModule()

in src/dlr_relayvm.cc [27:103]


void RelayVMModel::SetupVMModule(const std::vector<DLRModelElem>& model_elems) {
  // Set custom allocators in TVM.
  if (dlr::DLRAllocatorFunctions::GetMemalignFunction() &&
      dlr::DLRAllocatorFunctions::GetFreeFunction()) {
    auto* pf = tvm::runtime::Registry::Get("runtime.contrib.set_custom_cpu_allocator");
    if (pf) {
      (*pf)(reinterpret_cast<void*>(dlr::DLRAllocatorFunctions::GetMemalignFunction()),
            reinterpret_cast<void*>(dlr::DLRAllocatorFunctions::GetFreeFunction()));
    } else {
      LOG(WARNING) << "Custom allocator functions are not available. Using default allocators.";
    }
  } else if (dlr::DLRAllocatorFunctions::AnySet()) {
    LOG(WARNING) << "SetDLRCustomAllocatorFree() and SetDLRCustomAllocatorMemalign() must be set "
                    "to override TVM allocations. Using default allocators.";
  }

  std::string code_data;
  std::string model_lib_path;
  std::string metadata_data;
  for (DLRModelElem el : model_elems) {
    if (el.type == DLRModelElemType::RELAY_EXEC) {
      if (el.path != nullptr) {
        code_data = dlr::LoadFileToString(el.path, std::ios::binary);
      } else if (el.data != nullptr && el.data_size > 0) {
        code_data.assign(static_cast<const char*>(el.data), el.data_size);
      } else {
        throw dmlc::Error("Invalid RelayVM model element RELAY_EXEC");
      }
    } else if (el.type == DLRModelElemType::TVM_LIB) {
      if (el.path != nullptr) {
        model_lib_path = el.path;
      } else {
        throw dmlc::Error("Invalid RelayVM model element TVM_LIB. TVM_LIB must be a file path.");
      }
    } else if (el.type == DLRModelElemType::NEO_METADATA) {
      if (el.path != nullptr) {
        metadata_data = dlr::LoadFileToString(el.path);
      } else if (el.data != nullptr) {
        metadata_data = static_cast<const char*>(el.data);
      } else {
        throw dmlc::Error("Invalid model element NEO_METADATA");
      }
    }
  }
  if (code_data.empty() || model_lib_path.empty() || metadata_data.empty()) {
    throw dmlc::Error(
        "Invalid RelayVM model. Must have RELAY_EXEC, TVM_LIB and NEO_METADATA elements");
  }

  LoadJsonFromString(metadata_data, this->metadata_);
  ValidateDeviceTypeIfExists();
  // Override allocator - default is kPooled.
  const char* val = std::getenv("DLR_RELAYVM_ALLOCATOR");
  if ((metadata_.count("Model") && metadata_["Model"].count("RelayVMAllocator") &&
       metadata_["Model"]["RelayVMAllocator"].get<std::string>() == "naive") ||
      (val != nullptr && std::string(val) == "naive")) {
    allocator_type_ = tvm::runtime::vm::AllocatorType::kNaive;
  }

  tvm::runtime::Module lib = tvm::runtime::Module::LoadFromFile(model_lib_path);

  vm_executable_ =
      std::make_shared<tvm::runtime::Module>(tvm::runtime::vm::Executable::Load(code_data, lib));
  auto vm = tvm::runtime::make_object<tvm::runtime::vm::VirtualMachine>();
  vm->LoadExecutable(static_cast<tvm::runtime::vm::Executable*>(
      const_cast<tvm::runtime::Object*>(vm_executable_->get())));
  vm_module_ = std::make_shared<tvm::runtime::Module>(tvm::runtime::Module(vm));

  tvm::runtime::PackedFunc init = vm_module_->GetFunction("init");
  if (dev_.device_type == DLDeviceType::kDLCPU) {
    init(static_cast<int>(dev_.device_type), dev_.device_id, static_cast<int>(allocator_type_));
  } else {
    // CPU context also must be initialized because input/output data comes from CPU.
    init(static_cast<int>(dev_.device_type), dev_.device_id, static_cast<int>(allocator_type_),
         static_cast<int>(DLDeviceType::kDLCPU), 0, static_cast<int>(allocator_type_));
  }
}