in src/dlr_tvm.cc [29:138]
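// Builds the TVM graph executor from the given model elements: it optionally
// installs custom CPU allocators, loads the graph JSON, parameter blob,
// compiled library and optional Neo metadata, then caches the input/output
// tensors and their types.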
void TVMModel::SetupTVMModule(const std::vector<DLRModelElem>& model_elems) {
// Set custom allocators in TVM.
if (dlr::DLRAllocatorFunctions::GetMemalignFunction() &&
dlr::DLRAllocatorFunctions::GetFreeFunction()) {
auto* pf = tvm::runtime::Registry::Get("runtime.contrib.set_custom_cpu_allocator");
if (pf) {
(*pf)(reinterpret_cast<void*>(dlr::DLRAllocatorFunctions::GetMemalignFunction()),
reinterpret_cast<void*>(dlr::DLRAllocatorFunctions::GetFreeFunction()));
} else {
LOG(WARNING) << "Custom allocator functions are not available. Using default allocators.";
}
} else if (dlr::DLRAllocatorFunctions::AnySet()) {
LOG(WARNING) << "SetDLRCustomAllocatorFree() and SetDLRCustomAllocatorMemalign() must be set "
"to override TVM allocations. Using default allocators.";
}
std::string graph_str;
std::string params_str;
const char* params_data = nullptr;
size_t params_size = 0;
std::string model_lib_path;
std::string metadata_data;
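// Gather the graph JSON, parameter blob, compiled library path and optional
// metadata. Each element may be supplied as a file path or as an in-memory
// buffer, except TVM_LIB, which must be a file path.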
for (DLRModelElem el : model_elems) {
if (el.type == DLRModelElemType::TVM_GRAPH) {
if (el.path != nullptr) {
graph_str = dlr::LoadFileToString(el.path);
} else if (el.data != nullptr) {
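// el.data is treated as a NUL-terminated JSON string here.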
graph_str = static_cast<const char*>(el.data);
} else {
throw dmlc::Error("Invalid TVM model element TVM_GRAPH");
}
} else if (el.type == DLRModelElemType::TVM_PARAMS) {
if (el.path != nullptr) {
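// Read the parameter file into params_str so the bytes stay alive until
// LoadParams() consumes them below.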
std::ifstream pstream(el.path, std::ios::in | std::ios::binary);
std::stringstream params_blob;
params_blob << pstream.rdbuf();
params_str = params_blob.str();
params_data = params_str.data();
params_size = params_str.size();
} else if (el.data != nullptr && el.data_size > 0) {
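// Use the caller-provided buffer directly; no copy is made at this point.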
params_data = static_cast<const char*>(el.data);
params_size = el.data_size;
} else {
throw dmlc::Error("Invalid TVM model element TVM_PARAMS");
}
} else if (el.type == DLRModelElemType::TVM_LIB) {
if (el.path != nullptr) {
model_lib_path = el.path;
} else {
throw dmlc::Error("Invalid TVM model element TVM_LIB. TVM_LIB must be a file path.");
}
} else if (el.type == DLRModelElemType::NEO_METADATA) {
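// Metadata is optional, so an element with neither a path nor data is ignored.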
if (el.path != nullptr) {
metadata_data = dlr::LoadFileToString(el.path);
} else if (el.data != nullptr) {
metadata_data = static_cast<const char*>(el.data);
}
}
}
if (graph_str.empty() || params_data == nullptr || params_size == 0 || model_lib_path.empty()) {
throw dmlc::Error("Invalid TVM model. Must have TVM_GRAPH, TVM_PARAMS and TVM_LIB elements");
}
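// Parse the Neo metadata, if any, and validate the device type when it is specified.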
if (!metadata_data.empty()) {
LoadJsonFromString(metadata_data, this->metadata_);
ValidateDeviceTypeIfExists();
}
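// Load the compiled model library and initialize the graph executor on the
// configured device.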
tvm::runtime::Module module = tvm::runtime::Module::LoadFromFile(model_lib_path);
tvm_graph_executor_ = tvm::runtime::make_object<tvm::runtime::GraphExecutor>();
tvm_graph_executor_->Init(graph_str, module, {dev_}, nullptr);
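// Deserialize the parameter blob straight from memory; the const_cast is safe
// because the stream is only read.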
dmlc::MemoryFixedSizeStream strm(const_cast<char*>(params_data), params_size);
tvm_graph_executor_->LoadParams(&strm);
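// Expose the executor through a shared tvm::runtime::Module handle.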
tvm_module_ = std::make_shared<tvm::runtime::Module>(tvm::runtime::Module(tvm_graph_executor_));
// Get list of weights.
weight_names_ = tvm_graph_executor_->GetWeightNames();
num_weights_ = weight_names_.size();
std::unordered_set<std::string> weight_names_set(weight_names_.begin(), weight_names_.end());
// TVM's input list contains both model inputs and weights.
const auto num_inputs_weights = tvm_graph_executor_->NumInputs();
// Filter out weights to get only inputs.
for (int i = 0; i < num_inputs_weights; i++) {
auto name = tvm_graph_executor_->GetInputName(i);
if (weight_names_set.count(name) == 0) {
input_names_.push_back(name);
}
}
// Save the number of inputs
num_inputs_ = input_names_.size();
inputs_.resize(num_inputs_);
input_types_.resize(num_inputs_);
for (int i = 0; i < num_inputs_; i++) {
inputs_[i] = tvm_graph_executor_->GetInput(i);
input_types_[i] = tvm_graph_executor_->GetInputType(i);
}
// Get the number of outputs and reserve space for the output tensor pointers.
num_outputs_ = tvm_graph_executor_->NumOutputs();
outputs_.resize(num_outputs_);
output_types_.resize(num_outputs_);
for (int i = 0; i < num_outputs_; i++) {
outputs_[i] = tvm_graph_executor_->GetOutput(i);
output_types_[i] = tvm_graph_executor_->GetOutputType(i);
}
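// Update the stored input shapes to match the inputs collected above.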
UpdateInputShapes();
}