// TVMModel::SetupTVMModule — from src/dlr_tvm.cc [29:138]

void TVMModel::SetupTVMModule(const std::vector<DLRModelElem>& model_elems) {
  // Set custom allocators in TVM.
  if (dlr::DLRAllocatorFunctions::GetMemalignFunction() &&
      dlr::DLRAllocatorFunctions::GetFreeFunction()) {
    auto* pf = tvm::runtime::Registry::Get("runtime.contrib.set_custom_cpu_allocator");
    if (pf) {
      (*pf)(reinterpret_cast<void*>(dlr::DLRAllocatorFunctions::GetMemalignFunction()),
            reinterpret_cast<void*>(dlr::DLRAllocatorFunctions::GetFreeFunction()));
    } else {
      LOG(WARNING) << "Custom allocator functions are not available. Using default allocators.";
    }
  } else if (dlr::DLRAllocatorFunctions::AnySet()) {
    LOG(WARNING) << "SetDLRCustomAllocatorFree() and SetDLRCustomAllocatorMemalign() must be set "
                    "to override TVM allocations. Using default allocators.";
  }

  std::string graph_str;
  DLRString params_str;
  const char* params_data = nullptr;
  size_t params_size = 0;
  std::string model_lib_path;
  std::string metadata_data;
  for (DLRModelElem el : model_elems) {
    if (el.type == DLRModelElemType::TVM_GRAPH) {
      if (el.path != nullptr) {
        graph_str = dlr::LoadFileToString(el.path);
      } else if (el.data != nullptr) {
        graph_str = static_cast<const char*>(el.data);
      } else {
        throw dmlc::Error("Invalid TVM model element TVM_GRAPH");
      }
    } else if (el.type == DLRModelElemType::TVM_PARAMS) {
      if (el.path != nullptr) {
        std::ifstream pstream(el.path, std::ios::in | std::ios::binary);
        DLRStringStream params_blob;
        params_blob << pstream.rdbuf();
        params_str = params_blob.str();
        params_data = params_str.data();
        params_size = params_str.size();
      } else if (el.data != nullptr && el.data_size > 0) {
        params_data = static_cast<const char*>(el.data);
        params_size = el.data_size;
      } else {
        throw dmlc::Error("Invalid TVM model element TVM_PARAMS");
      }
    } else if (el.type == DLRModelElemType::TVM_LIB) {
      if (el.path != nullptr) {
        model_lib_path = el.path;
      } else {
        throw dmlc::Error("Invalid TVM model element TVM_LIB. TVM_LIB must be a file path.");
      }
    } else if (el.type == DLRModelElemType::NEO_METADATA) {
      if (el.path != nullptr) {
        metadata_data = dlr::LoadFileToString(el.path);
      } else if (el.data != nullptr) {
        metadata_data = static_cast<const char*>(el.data);
      }
    }
  }
  if (graph_str.empty() || params_data == nullptr || params_size <= 0 || model_lib_path.empty()) {
    throw dmlc::Error("Invalid TVM model. Must have TVM_GRAPH, TVM_PARAMS and TVM_LIB elements");
  }
  if (!metadata_data.empty()) {
    LoadJsonFromString(metadata_data, this->metadata_);
    ValidateDeviceTypeIfExists();
  }

  tvm::runtime::Module module;
  module = tvm::runtime::Module::LoadFromFile(model_lib_path);

  tvm_graph_executor_ = tvm::runtime::make_object<tvm::runtime::GraphExecutor>();
  tvm_graph_executor_->Init(graph_str, module, {dev_}, nullptr);
  dmlc::MemoryFixedSizeStream strm(const_cast<char*>(params_data), params_size);
  tvm_graph_executor_->LoadParams(&strm);

  tvm_module_ = std::make_shared<tvm::runtime::Module>(tvm::runtime::Module(tvm_graph_executor_));

  // Get list of weights.
  weight_names_ = tvm_graph_executor_->GetWeightNames();
  num_weights_ = weight_names_.size();
  std::unordered_set<std::string> weight_names_set(weight_names_.begin(), weight_names_.end());
  // TVM inputs contains both inputs and weights.
  const auto num_inputs_weights = tvm_graph_executor_->NumInputs();
  // Filter out weights to get only inputs.
  for (int i = 0; i < num_inputs_weights; i++) {
    auto name = tvm_graph_executor_->GetInputName(i);
    if (weight_names_set.count(name) == 0) {
      input_names_.push_back(name);
    }
  }
  // Save the number of inputs
  num_inputs_ = input_names_.size();
  inputs_.resize(num_inputs_);
  input_types_.resize(num_inputs_);
  for (int i = 0; i < num_inputs_; i++) {
    inputs_[i] = tvm_graph_executor_->GetInput(i);
    input_types_[i] = tvm_graph_executor_->GetInputType(i);
  }

  // Get the number of output and reserve space to save output tensor
  // pointers.
  num_outputs_ = tvm_graph_executor_->NumOutputs();
  outputs_.resize(num_outputs_);
  output_types_.resize(num_outputs_);
  for (int i = 0; i < num_outputs_; i++) {
    outputs_[i] = tvm_graph_executor_->GetOutput(i);
    output_types_[i] = tvm_graph_executor_->GetOutputType(i);
  }
  UpdateInputShapes();
}