src/runtime/relax_vm/vm.cc (740 lines of code) (raw):

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/runtime/relax_vm/vm.cc
 */
#include <dlpack/dlpack.h>
#include <tvm/runtime/memory/memory_manager.h>
#include <tvm/runtime/nvtx.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/profiling.h>
#include <tvm/runtime/relax_vm/vm.h>

#include <optional>
#include <thread>

namespace tvm {
namespace runtime {
namespace relax_vm {

//---------------------------------------------
// VM Closure object
//---------------------------------------------
TVM_REGISTER_OBJECT_TYPE(VMClosureObj);

/*!
 * \brief Construct a closure that pairs a function name with its packed implementation.
 * \param func_name The name recorded on the closure.
 * \param impl The packed function that implements the closure.
 */
VMClosure::VMClosure(String func_name, PackedFunc impl) {
  ObjectPtr<VMClosureObj> node = make_object<VMClosureObj>();
  node->impl = std::move(impl);
  node->func_name = func_name;
  data_ = std::move(node);
}

/*!
 * \brief Create another PackedFunc with last arguments already bound to last_args.
 * \param func The input func, can be a VMClosure or PackedFunc.
 * \param last_args The arguments to bind at the end of the function.
 * \note The new function takes in arguments and appends the last_args in the end.
*/ PackedFunc VMClosure::BindLastArgs(PackedFunc func, std::vector<Any> last_args) { return PackedFunc([func, last_args](TVMArgs args, TVMRetValue* rv) { std::vector<AnyView> packed_args(args.size() + last_args.size()); std::copy(args.data(), args.data() + args.size(), packed_args.data()); for (size_t i = 0; i < last_args.size(); ++i) { packed_args[args.size() + i] = last_args[i]; } func.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), rv); }); } //----------------------------------------------------------- // Utility functions. //----------------------------------------------------------- // Use the args after `starting_arg_idx` as a series of indices into `obj`, // indexing into nested Array and returning the final indexed object. Any IndexIntoNestedObject(Any obj, TVMArgs args, int starting_arg_idx) { for (int i = starting_arg_idx; i < args.size(); i++) { // the object must be an Array to be able to index into it if (!obj.as<ffi::ArrayObj>()) { LOG(FATAL) << "ValueError: Attempted to index into an object that is not an Array."; } int index = args[i].cast<int>(); auto arr = Downcast<ffi::Array<Any>>(obj); // make sure the index is in bounds if (index >= static_cast<int>(arr.size())) { LOG(FATAL) << "IndexError: Invalid index (" << index << " >= " << arr.size() << ")."; } obj = arr[index]; } return obj; } NDArray ConvertNDArrayToDevice(NDArray src, const DLDevice& dev, Allocator* alloc) { if (src->device.device_type == dev.device_type && src->device.device_id == dev.device_id) { return src; } else { auto res = alloc->Empty(src.Shape(), src->dtype, dev); res.CopyFrom(src); return res; } } Any ConvertObjectToDevice(Any src, const Device& dev, Allocator* alloc) { if (src.as<NDArray::ContainerType>()) { return ConvertNDArrayToDevice(Downcast<NDArray>(src), dev, alloc); } else if (src.as<ffi::ArrayObj>()) { std::vector<Any> ret; auto arr = Downcast<ffi::Array<Any>>(src); for (size_t i = 0; i < arr.size(); i++) { 
ret.push_back(ConvertObjectToDevice(arr[i], dev, alloc)); } return Array<Any>(ret.begin(), ret.end()); } else { return src; } } TVMRetValue ConvertArgToDevice(AnyView input, Device dev, Allocator* alloc) { // in terms of memory-behavior. // To be extra careful, we copy DLTensor. // The developer can still explicitly allocate NDArray // in TVM Native API or NDArray::FromDLPack to regain zero copy behavior. Any ret; if (auto opt_obj = input.as<ObjectRef>()) { ret = ConvertObjectToDevice(opt_obj.value(), dev, alloc); } else if (auto opt_dltensor = input.as<DLTensor*>()) { DLTensor* tensor = opt_dltensor.value(); std::vector<int64_t> shape(tensor->shape, tensor->shape + tensor->ndim); auto dst = alloc->Empty(shape, tensor->dtype, dev); dst.CopyFrom(tensor); ret = dst; } else { ret = input; } return ret; } TVMRetValue ConvertRegToDevice(TVMRetValue input, Device dev, Allocator* alloc) { Any ret; if (auto opt_obj = input.as<ObjectRef>()) { ret = ConvertObjectToDevice(opt_obj.value(), dev, alloc); } else { ret = input; } return ret; } //----------------------------------------------------------- // VM implementations. //----------------------------------------------------------- /*! * \brief The register type. */ using RegType = TVMRetValue; /*! * \brief A representation of a stack frame. * * A stack frame is a record containing the information needed * to restore the caller's virtual machine state after returning * from a function call. */ struct VMFrame { /*! \brief The return program counter. */ Index return_pc; /*! \brief Statically allocated space for objects */ std::vector<RegType> register_file; /*! \brief Register in caller's frame to put return value */ RegName caller_return_register; // The following fields are used for PackedFunc call within // a single function scope. The space is reused across multiple // packed func calls to increase cache locality and avoid re-allocation /*! \brief Temporary argument value stack for packed func call. 
*/ std::vector<TVMValue> call_arg_values; /*! \brief Temporary argument tcode stack for packed func call. */ std::vector<int> call_arg_tcodes; std::vector<AnyView> call_args; VMFrame(Index pc, Index register_file_size) : return_pc(pc), register_file(register_file_size), caller_return_register(0) {} void Clear() { this->caller_return_register = 0; this->call_args.clear(); for (RegType& reg : register_file) { reg = nullptr; } } void ResetForRecycle(Index pc, Index register_file_size) { this->return_pc = pc; this->register_file.resize(register_file_size); } }; class VirtualMachineImpl : public VirtualMachine { public: //--------------------------------------------------- // Public facing functions overloading //--------------------------------------------------- void LoadExecutable(ObjectPtr<VMExecutable> exec) final; void Init(const std::vector<Device>& devices, const std::vector<AllocatorType>& alloc_types) final; VMClosure GetClosure(const String& func_name) final { return this->GetClosureInternal(func_name, false).value(); } void InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs args, TVMRetValue* rv) final; void SetInstrument(PackedFunc instrument) final { this->instrument_ = instrument; } //--------------------------------------------------- // Functions in the vtable of Module //--------------------------------------------------- void _Init(TVMArgs args, TVMRetValue* rv); void _SaveClosure(TVMArgs args, TVMRetValue* rv); void _InvokeClosure(TVMArgs args, TVMRetValue* rv); void _InvokeClosureStateful(std::string func_name); void _SetInstrument(TVMArgs args, TVMRetValue* rv); void _GetOutputArity(TVMArgs args, TVMRetValue* rv); void _GetOutput(TVMArgs args, TVMRetValue* rv); void _SetInputWithoutParamModule(TVMArgs args, TVMRetValue* rv); void _SetInputWithParamModule(TVMArgs args, TVMRetValue* rv); int _GetFunctionArity(std::string func_name); std::string _GetFunctionParamName(std::string func_name, int index); PackedFunc _LookupFunction(const 
String& name); TVM_MODULE_VTABLE_BEGIN("relax.VirtualMachine"); TVM_MODULE_VTABLE_ENTRY_PACKED("vm_initialization", &VirtualMachineImpl::_Init); TVM_MODULE_VTABLE_ENTRY_PACKED("save_function", &VirtualMachineImpl::_SaveClosure); TVM_MODULE_VTABLE_ENTRY_PACKED("invoke_closure", &VirtualMachineImpl::_InvokeClosure); TVM_MODULE_VTABLE_ENTRY("invoke_stateful", &VirtualMachineImpl::_InvokeClosureStateful); TVM_MODULE_VTABLE_ENTRY_PACKED("set_instrument", &VirtualMachineImpl::_SetInstrument); TVM_MODULE_VTABLE_ENTRY_PACKED("get_output_arity", &VirtualMachineImpl::_GetOutputArity); TVM_MODULE_VTABLE_ENTRY_PACKED("get_output", &VirtualMachineImpl::_GetOutput); TVM_MODULE_VTABLE_ENTRY_PACKED("set_input", &VirtualMachineImpl::_SetInputWithoutParamModule); TVM_MODULE_VTABLE_ENTRY_PACKED("set_input_with_param_module", &VirtualMachineImpl::_SetInputWithParamModule); TVM_MODULE_VTABLE_ENTRY("get_function_arity", &VirtualMachineImpl::_GetFunctionArity); TVM_MODULE_VTABLE_ENTRY("get_function_param_name", &VirtualMachineImpl::_GetFunctionParamName); TVM_MODULE_VTABLE_END_WITH_DEFAULT(&VirtualMachineImpl::_LookupFunction); //-------------------------------------------------- // Additional support arguments functions for VM //-------------------------------------------------- /*! * \brief Internal implementation of GetClosure which also allow none. * \param func_name The name of the function. * \param allow_missing Whether none is allowed. * \return The result */ Optional<VMClosure> GetClosureInternal(const String& func_name, bool allow_missing); /*! * \brief Set inputs to a function. * \param func_name The function name. * \param args args[offset:] are arguments to the function. If the arguments are not of the * correct device for the function, they will be copied to the device. * \param with_param_module If set to true, the last argument will be a module and can be invoked * to get the argument, this is mainly used for debugging purposes and setting composite * objects. 
\note This interface works when using VM over RPC by internally converting NDArray in * the arguments to DLTensor, which is supported in RPC where remote could only have a minimal C * runtime. */ void SetInput(std::string func_name, bool with_param_module, TVMArgs args); /*! * \brief Look up whether the VM has a function by the given name. * \param func_name the function's name * \return The function, if it exists. Logs a fatal error if not. */ VMFuncInfo LookupVMFuncInfo(const std::string& func_name); /*! * \brief Look up whether the VM has outputs for the given function. * \param func_name the function's name * \return The output, if it exists. Logs a fatal error if not. */ RegType LookupVMOutput(const std::string& func_name); /*! * \brief Fully bind the argument of a global function and save it in the env. * \param func_name The global function name to be saved. * \param save_name The saved name of the function. * \param include_return Whether forward the return value, set it to false allows * us to ignore forwarding return value, which can be helpful to do benchmarking * in RPC environment when return value is complicated Array. * * \param args The arguments to bound to the function. * \note This function is used by RPC server to help benchmarking. */ void SaveClosure(const String& func_name, const String& save_name, bool include_return, TVMArgs args); /*! * \brief Internal function to invoke a closure. * \param closure_or_packed The closure to be invoked. * \param args The arguments to the function. * \return The result value. */ RegType InvokeClosureInternal(const ObjectRef& closure_or_packed, const std::vector<RegType>& args); /*! * \brief Invoke a VM function by interpreting bytecode. * \param fidx The function index. * \param args The arguments to the function. * \return The object representing the result. */ RegType InvokeBytecode(Index fidx, const std::vector<RegType>& args); protected: /*! 
* \brief Get function by querying all of the current module's imports. * \param name The name of the function. * \return The result function, can return PackedFunc(nullptr) if nothing is found. */ PackedFunc GetFuncFromImports(const String& name) { for (auto& lib : this->imports_) { PackedFunc func = lib->GetFunction(name, true); if (func.defined()) return func; } return PackedFunc(nullptr); } /*! * \brief Initialize function pool. */ void InitFuncPool(); /*! * \brief A RAII wrapper that pushes and pops VM frames. */ class FrameGuard { public: VirtualMachineImpl* vm; explicit FrameGuard(VirtualMachineImpl* vm, std::unique_ptr<VMFrame> frame) : vm(vm) { vm->frames_.emplace_back(std::move(frame)); } ~FrameGuard() { ICHECK_GT(vm->frames_.size(), 0); vm->pc_ = vm->frames_.back()->return_pc; vm->frames_.back()->Clear(); vm->frame_free_list_.emplace_back(std::move(vm->frames_.back())); vm->frames_.pop_back(); } }; //------------------------------------------------- // Instruction interpretations. //------------------------------------------------- /*! * \brief Push a call frame onto the call stack. * \param ret_pc The program counter to return to. * \param vm_func The function to be pushed to the call stack. * \return A RAII wrapper that pops the frame when going out of scope. */ FrameGuard PushFrame(Index ret_pc, const VMFuncInfo& vm_func) { std::unique_ptr<VMFrame> new_frame; if (!frame_free_list_.empty()) { new_frame = std::move(frame_free_list_.back()); frame_free_list_.pop_back(); new_frame->ResetForRecycle(ret_pc, vm_func.register_file_size); } else { new_frame = std::make_unique<VMFrame>(ret_pc, vm_func.register_file_size); } return FrameGuard(this, std::move(new_frame)); } /*! * \brief Write to a VM register. * \param frame current vm frame. * \param reg The register to write to. * \param obj The object to write to. 
*/ TVM_ALWAYS_INLINE void WriteRegister(VMFrame* frame, RegName reg, const RegType& obj) { ICHECK_LT(reg, frame->register_file.size()); frame->register_file[reg] = obj; } /*! * \brief Read a VM register. * \param frame current vm frame. * \param reg The register to read from. * \return The value of the register. */ TVM_ALWAYS_INLINE RegType ReadRegister(VMFrame* frame, RegName reg) { if (reg < Instruction::kBeginSpecialReg) { return frame->register_file[reg]; } RegType ret; if (reg == Instruction::kVoidRegister) { ret = nullptr; } else { ICHECK_EQ(reg, Instruction::kVMRegister); // per convention, ctx ptr must be VirtualMachine* casted to void. // this and VirtualMachine* may or may not be the same // do first cast to VirtualMachine* then to void* ret = static_cast<void*>(static_cast<VirtualMachine*>(this)); } return ret; } /*! * \brief Run call instruction. * \param curr_frame The current frame. * \param inst The call instruction. */ virtual void RunInstrCall(VMFrame* curr_frame, Instruction inst); /*! \brief Run VM dispatch loop. */ void RunLoop(); /*! * \brief Retrieve the name of the function identified by the given index. * \param idx The index into the VM executable function table. * \return The name of the function. */ const std::string& GetFuncName(int idx) { return exec_->func_table[idx].name; } /*! * \brief Retrieve the inputs for a function. * \param func_name The name of the function. * \return The function inputs. */ const std::vector<RegType>& GetInputsFor(const std::string& func_name) { return inputs_[func_name]; } void ClearInputsFor(const std::string& func_name) { inputs_.erase(func_name); } //-------------------------------------------------------- // Internal states for execution. //-------------------------------------------------------- /*! \brief The loaded executable. */ ObjectPtr<VMExecutable> exec_; /*! \brief The global constant pool */ std::vector<TVMRetValue> const_pool_; /*! 
* \brief Function pool to cache functions in func_table */ std::vector<TVMRetValue> func_pool_; //-------------------------------------------------------- // Executor interface support //-------------------------------------------------------- /*! \brief The function name to input register mapping. */ std::unordered_map<std::string, std::vector<RegType>> inputs_; /*! \brief The function name to output register. */ std::unordered_map<std::string, RegType> outputs_; /*! \brief A store of closures created by `save_function`. */ std::unordered_map<std::string, VMClosure> saved_closures_; //------------------------------------------------------------ // VM Instruction execution. //------------------------------------------------------------ /*! * \brief The current stack of call frames. * \note: Use unique ptr to avoid re-allocation and copy when frames_ get resized. */ std::vector<std::unique_ptr<VMFrame>> frames_; /*! * \brief A free list of frame */ std::vector<std::unique_ptr<VMFrame>> frame_free_list_; /*! \brief The virtual machine PC. */ Index pc_{0}; /*! \brief The special return register. */ RegType return_value_; /*!\ brief instrument function. */ PackedFunc instrument_ = nullptr; }; void VirtualMachineImpl::LoadExecutable(ObjectPtr<VMExecutable> exec) { this->exec_ = exec; this->imports_ = exec_->imports(); } void VirtualMachineImpl::Init(const std::vector<Device>& devices, const std::vector<AllocatorType>& alloc_types) { ICHECK_EQ(devices.size(), alloc_types.size()); this->devices.reserve(devices.size()); this->allocators.reserve(alloc_types.size()); for (size_t i = 0; i < devices.size(); i++) { auto alloc = MemoryManager::GetOrCreateAllocator(devices[i], alloc_types[i]); this->devices.push_back(devices[i]); this->allocators.push_back(alloc); } // Setup constant sections. 
this->const_pool_.reserve(exec_->constants.size()); for (const auto& constant : exec_->constants) { if (auto opt_nd = constant.as<NDArray>()) { this->const_pool_.push_back(ConvertRegToDevice(opt_nd.value(), devices[0], allocators[0])); } else { this->const_pool_.push_back(constant); } } // Setup function sections. this->InitFuncPool(); } VMFuncInfo VirtualMachineImpl::LookupVMFuncInfo(const std::string& func_name) { ICHECK(exec_) << "The executable is not created yet."; auto it = this->exec_->func_map.find(func_name); CHECK(it != this->exec_->func_map.end()) << "ValueError: Unknown function: " << func_name; return exec_->func_table[it->second]; } RegType VirtualMachineImpl::LookupVMOutput(const std::string& func_name) { if (!outputs_.count(func_name)) { LOG(FATAL) << "ValueError: No output saved for call of \"" << func_name << "\"; use `invoke_stateful` to call it first."; } return outputs_[func_name]; } void VirtualMachineImpl::SetInput(std::string func_name, bool with_param_module, TVMArgs args) { const auto& m = exec_->func_map; if (m.find(func_name) != m.end()) { Index gf_idx = m.at(func_name); const VMFuncInfo& vm_func = exec_->func_table[gf_idx]; size_t params_num = vm_func.num_args; ICHECK_EQ(args.size(), params_num) << "The number of provided parameters doesn't match the number of arguments for"; std::vector<RegType> func_args(params_num); for (int i = 0; i < args.size(); ++i) { if (with_param_module && i == args.size() - 1) { // call param func to get the arguments(usually corresponds to param pack.) 
func_args[i] = (args[i].cast<Module>()).GetFunction("get_params")(); } else { func_args[i] = ConvertArgToDevice(args[i], devices[0], allocators[0]); } } inputs_[func_name] = func_args; } else { LOG(FATAL) << "ValueError: Unknown function: " << func_name; } } //------------------------------------------ // Closure handling //------------------------------------------ void VirtualMachineImpl::InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs args, TVMRetValue* rv) { // run packed call if it is a packed func. if (auto* packed = closure_or_packedfunc.as<PackedFunc::ContainerType>()) { packed->CallPacked(args.data(), args.size(), rv); return; } // run closure call. auto* clo = closure_or_packedfunc.as<VMClosureObj>(); ICHECK(clo != nullptr) << "Function expects a closure or PackedFunc "; std::vector<AnyView> packed_args(args.size() + 1); // per convention, ctx ptr must be VirtualMachine* casted to void. // this and VirtualMachine* may or maynot be the same // do first cast to VirtualMachine* then to void* packed_args[0] = static_cast<void*>(static_cast<VirtualMachine*>(this)); std::copy(args.data(), args.data() + args.size(), packed_args.begin() + 1); { NVTXScopedRange scope("RelaxVM: " + clo->func_name); clo->impl.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), rv); } } // internal variant version of invoke closurepacked RegType VirtualMachineImpl::InvokeClosureInternal(const ObjectRef& closure_or_packed, const std::vector<RegType>& args) { RegType ret; auto* packed = closure_or_packed.as<PackedFunc::ContainerType>(); auto* clo = closure_or_packed.as<VMClosureObj>(); int clo_offset = clo != nullptr ? 
1 : 0; std::vector<AnyView> packed_args(args.size() + clo_offset); if (clo != nullptr) { packed_args[0] = static_cast<void*>(static_cast<VirtualMachine*>(this)); } for (size_t i = 0; i < args.size(); ++i) { packed_args[i + clo_offset] = args[i]; } if (packed != nullptr) { packed->CallPacked(packed_args.data(), packed_args.size(), &ret); } else { ICHECK(clo != nullptr); clo->impl.CallPacked(packed_args.data(), packed_args.size(), &ret); } return ret; } void VirtualMachineImpl::SaveClosure(const String& func_name, const String& save_name, bool include_return, TVMArgs args) { VMClosure clo = this->GetClosure(func_name); std::vector<RegType> inputs(args.size()); for (int i = 0; i < args.size(); ++i) { inputs[i] = ConvertArgToDevice(args[i], this->devices[0], this->allocators[0]); } PackedFunc impl = VMClosure::BindLastArgs(clo->impl, inputs); if (!include_return) { impl = PackedFunc([impl](TVMArgs args, TVMRetValue* rv) { TVMRetValue temp; impl.CallPacked(args, &temp); }); } saved_closures_[save_name] = VMClosure(save_name, impl); } Optional<VMClosure> VirtualMachineImpl::GetClosureInternal(const String& func_name, bool allow_missing) { // look up saved closures. auto saved_it = saved_closures_.find(func_name); if (saved_it != saved_closures_.end()) { return saved_it->second; } auto it = exec_->func_map.find(func_name); if (it == exec_->func_map.end()) { if (allow_missing) return NullOpt; LOG(FATAL) << "ValueError: Unknown function: " << func_name; } Index gf_idx = it->second; const VMFuncInfo& finfo = exec_->func_table[gf_idx]; if (finfo.kind == VMFuncInfo::FuncKind::kVMFunc) { // NOTE: should not capture strong ref to self and avoid cyclic ref. 
auto impl = PackedFunc([gf_idx](TVMArgs args, TVMRetValue* rv) { // Per convention, ctx ptr is a VirtualMachine* VirtualMachine* ctx_ptr = static_cast<VirtualMachine*>(args[0].cast<void*>()); std::vector<RegType> inputs(args.size() - 1); for (size_t i = 0; i < inputs.size(); ++i) { inputs[i] = args[i + 1]; } *rv = static_cast<VirtualMachineImpl*>(ctx_ptr)->InvokeBytecode(gf_idx, inputs); }); return VMClosure(func_name, impl); } else { ICHECK(finfo.kind == VMFuncInfo::FuncKind::kVMTIRFunc) << "Cannot support closure with function kind " << static_cast<int>(finfo.kind); PackedFunc tir_func = GetFuncFromImports("__vmtir__" + finfo.name); ICHECK(tir_func != nullptr) << "Cannot find underlying compiled tir function of VMTIRFunc " << finfo.name; auto impl = PackedFunc([this, finfo, tir_func](TVMArgs args, TVMRetValue* rv) { // Per convention, ctx ptr is a VirtualMachine* VirtualMachine* ctx_ptr = static_cast<VirtualMachine*>(args[0].cast<void*>()); ICHECK(ctx_ptr == this); ICHECK_EQ(args.size() - 1, finfo.num_args) << "Function " << finfo.name << " expects " << finfo.num_args << " arguments"; ICHECK_GE(finfo.register_file_size, finfo.num_args + 1); std::vector<TVMRetValue> reg_file(finfo.register_file_size); for (int64_t i = 0; i < finfo.num_args; ++i) { reg_file[i] = args[i + 1]; } void* reg_anylist_handle = reg_file.data(); void* const_anylist_handle = this->const_pool_.data(); void* func_anylist_handle = this->func_pool_.data(); tir_func(static_cast<void*>(ctx_ptr), reg_anylist_handle, const_anylist_handle, func_anylist_handle); // Return value always stored after inputs. *rv = reg_file[finfo.num_args]; }); return VMClosure(func_name, impl); } } //-------------------------------------------------------------------- // Instruction interpretations. 
//-------------------------------------------------------------------- RegType VirtualMachineImpl::InvokeBytecode(Index gf_idx, const std::vector<RegType>& args) { const VMFuncInfo& gfunc = exec_->func_table[gf_idx]; ICHECK(gfunc.kind == VMFuncInfo::FuncKind::kVMFunc); // Get the curr instr which might be a potential caller. Instruction curr_instr = exec_->GetInstruction(pc_); auto guard = PushFrame(this->pc_, gfunc); // Get new frame and set the caller info. VMFrame* curr_frame = frames_.back().get(); if (curr_instr.op == Opcode::Call) { curr_frame->caller_return_register = curr_instr.dst; } // load arguments to the register file ICHECK_EQ(static_cast<size_t>(gfunc.num_args), args.size()) << "ValueError: Invoking function " << gfunc.name << " expects " << gfunc.num_args << " arguments" << [&]() { std::stringstream ss; if (gfunc.param_names.size()) { ss << " ("; for (size_t i = 0; i < gfunc.param_names.size(); i++) { if (i) { ss << ", "; } ss << gfunc.param_names[i]; } ss << ")"; } return ss.str(); }() << ", but " << args.size() << " arguments were provided."; for (size_t i = 0; i < args.size(); ++i) { WriteRegister(frames_.back().get(), i, args[i]); } // set program counter pc_ = gfunc.start_instr; RunLoop(); return return_value_; } void VirtualMachineImpl::InitFuncPool() { func_pool_.resize(exec_->func_table.size()); for (size_t func_index = 0; func_index < exec_->func_table.size(); ++func_index) { const VMFuncInfo& info = exec_->func_table[func_index]; if (info.kind == VMFuncInfo::FuncKind::kPackedFunc) { // only look through imports first PackedFunc func = GetFuncFromImports(info.name); if (!func.defined()) { const auto p_func = tvm::ffi::Function::GetGlobal(info.name); if (p_func.has_value()) func = *(p_func); } ICHECK(func.defined()) << "Error: Cannot find PackedFunc " << info.name << " in either Relax VM kernel library, or in TVM runtime PackedFunc registry, or in " "global Relax functions of the VM executable"; func_pool_[func_index] = func; } else { 
ICHECK(info.kind == VMFuncInfo::FuncKind::kVMFunc || info.kind == VMFuncInfo::FuncKind::kVMTIRFunc); auto clo = this->GetClosure(info.name); func_pool_[func_index] = clo; } } } void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) { DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << GetFuncName(instr.func_idx); int args_begin_offset = instrument_ != nullptr ? 4 : 0; // Use the call arg stack from the current frame to increase reuse // and avoid re-allocation curr_frame->call_args.resize(args_begin_offset + instr.num_args); // NOTE: no changes and resize to those vector ref(otherwise can leads to segfault) // in the remainder part of the function. std::vector<AnyView>& call_args = curr_frame->call_args; for (Index i = 0; i < instr.num_args; ++i) { Instruction::Arg arg = instr.args[i]; int arg_index = args_begin_offset + i; switch (arg.kind()) { case Instruction::ArgKind::kRegister: { call_args[arg_index] = ReadRegister(curr_frame, arg.value()); break; } case Instruction::ArgKind::kImmediate: { call_args[arg_index] = arg.value(); break; } case Instruction::ArgKind::kConstIdx: { call_args[arg_index] = this->const_pool_[arg.value()]; break; } case Instruction::ArgKind::kFuncIdx: { ICHECK_LT(static_cast<size_t>(arg.value()), this->func_pool_.size()); call_args[arg_index] = this->func_pool_[arg.value()]; break; } default: { LOG(FATAL) << "ValueError: Unknown argument kind: " << int(arg.kind()); } } } ffi::PackedArgs args(call_args.data() + args_begin_offset, instr.num_args); TVMRetValue ret; ICHECK_LT(static_cast<size_t>(instr.func_idx), this->func_pool_.size()); if (instrument_ == nullptr) { this->InvokeClosurePacked(func_pool_[instr.func_idx].cast<ObjectRef>(), args, &ret); } else { // insert light-weight instrument callback call_args[0] = func_pool_[instr.func_idx]; call_args[1] = GetFuncName(instr.func_idx); call_args[2] = true; call_args[3] = nullptr; Any rv; // store dtype to str since py callback cannot handle dtype atm. 
std::vector<std::unique_ptr<std::string>> temp_dtype; for (int i = 0; i < instr.num_args; ++i) { if (call_args[i + args_begin_offset].type_index() == ffi::TypeIndex::kTVMFFIDataType) { std::string str_dtype = DLDataTypeToString(call_args[i + args_begin_offset].cast<DLDataType>()); temp_dtype.emplace_back(std::make_unique<std::string>(str_dtype)); call_args[i + args_begin_offset] = *temp_dtype.back(); } } int ret_kind = static_cast<int>(VMInstrumentReturnKind::kNoOp); instrument_.CallPacked(call_args.data(), call_args.size(), &rv); if (auto opt_int = rv.as<int64_t>()) { ret_kind = opt_int.value(); } if (ret_kind != static_cast<int>(VMInstrumentReturnKind::kSkipRun)) { this->InvokeClosurePacked(func_pool_[instr.func_idx].cast<ObjectRef>(), args, &ret); call_args[2] = false; call_args[3] = ret; instrument_.CallPacked(call_args.data(), call_args.size(), &rv); } } // save the return value to the register // saving to special register is a NOP if (instr.dst < Instruction::kBeginSpecialReg) { WriteRegister(curr_frame, instr.dst, ret); } // increment pc pc_++; } void VirtualMachineImpl::RunLoop() { VMFrame* curr_frame = frames_.back().get(); while (true) { ICHECK_LT(static_cast<size_t>(pc_), exec_->instr_offset.size()) << "run into invalid section"; Instruction instr = exec_->GetInstruction(pc_); switch (instr.op) { case Opcode::Call: { this->RunInstrCall(curr_frame, instr); break; } case Opcode::Ret: { // If we have hit the point from which we started // running, we should return to the caller breaking // the dispatch loop. return_value_ = ReadRegister(curr_frame, instr.result); RegName caller_return_register = curr_frame->caller_return_register; if (frames_.size() <= 1) { // directly return if no other frame in the call stack. } else { // return from a local call. // Update the current frame to be the parent frame. 
VMFrame* parent_frame = frames_.end()[-2].get(); WriteRegister(parent_frame, caller_return_register, return_value_); } return; } case Opcode::Goto: { pc_ += instr.pc_offset; break; } case Opcode::If: { int64_t cond_val = ReadRegister(curr_frame, instr.cond).cast<int64_t>(); if (cond_val != 0) { pc_++; } else { ICHECK_GT(instr.false_offset, 1); pc_ += instr.false_offset; } break; } } } } ObjectPtr<VirtualMachine> VirtualMachine::Create() { return make_object<VirtualMachineImpl>(); } //-------------------------------------------------------------------- // FFI related code //-------------------------------------------------------------------- void VirtualMachineImpl::_Init(TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.size() % 3, 0); std::vector<Device> devices; std::vector<AllocatorType> alloc_types; for (int i = 0; i < args.size(); i += 3) { int device_type = args[i].cast<int>(); int device_id = args[i + 1].cast<int>(); int alloc_type = args[i + 2].cast<int>(); devices.push_back(Device{DLDeviceType(device_type), device_id}); alloc_types.push_back(AllocatorType(alloc_type)); } this->Init(devices, alloc_types); } void VirtualMachineImpl::_SaveClosure(TVMArgs args, TVMRetValue* rv) { ICHECK_GE(args.size(), 3); std::string func_name = args[0].cast<std::string>(); this->SaveClosure(func_name, args[1].cast<String>(), args[2].cast<bool>(), args.Slice(3)); } void VirtualMachineImpl::_InvokeClosure(TVMArgs args, TVMRetValue* rv) { this->InvokeClosurePacked(args[0].cast<ObjectRef>(), args.Slice(1), rv); } void VirtualMachineImpl::_InvokeClosureStateful(std::string func_name) { const std::unordered_map<std::string, Index>& m = this->exec_->func_map; if (m.find(func_name) == m.end()) { LOG(FATAL) << "ValueError: Unknown function: " << func_name; } if (!inputs_.count(func_name)) { LOG(FATAL) << "ValueError: No inputs set for stateful call of " << func_name << "; use `set_input` first."; return; } outputs_[func_name] = 
this->InvokeClosureInternal(func_pool_[m.at(func_name)].cast<ObjectRef>(), inputs_[func_name]); } void VirtualMachineImpl::_SetInstrument(TVMArgs args, TVMRetValue* rv) { if (args[0].as<ffi::Function>()) { this->SetInstrument(args[0].cast<PackedFunc>()); } else { String func_name = args[0].cast<String>(); const auto factory = tvm::ffi::Function::GetGlobal(func_name); CHECK(factory.has_value()) << "Cannot find factory " << func_name; TVMRetValue rv; factory->CallPacked(args.Slice(1), &rv); this->SetInstrument(rv.cast<PackedFunc>()); } } void VirtualMachineImpl::_GetOutputArity(TVMArgs args, TVMRetValue* rv) { std::string func_name = args[0].cast<std::string>(); RegType out = LookupVMOutput(func_name); Any obj = IndexIntoNestedObject(out, args, 1); if (const auto* arr = obj.as<ffi::ArrayObj>()) { *rv = static_cast<int>(arr->size()); } else { *rv = -1; } } void VirtualMachineImpl::_GetOutput(TVMArgs args, TVMRetValue* rv) { std::string func_name = args[0].cast<std::string>(); RegType out = LookupVMOutput(func_name); Any obj = IndexIntoNestedObject(out, args, 1); if (obj.as<ffi::ArrayObj>()) { LOG(FATAL) << "ValueError: `get_output` cannot return a tuple for RPC compatibility. 
" "Please specify another index argument."; return; } *rv = obj; } void VirtualMachineImpl::_SetInputWithoutParamModule(TVMArgs args, TVMRetValue* rv) { std::string func_name = args[0].cast<std::string>(); this->SetInput(func_name, false, args.Slice(1)); } void VirtualMachineImpl::_SetInputWithParamModule(TVMArgs args, TVMRetValue* rv) { std::string func_name = args[0].cast<std::string>(); this->SetInput(func_name, true, args.Slice(1)); } int VirtualMachineImpl::_GetFunctionArity(std::string func_name) { const VMFuncInfo& vm_func = LookupVMFuncInfo(func_name); return vm_func.param_names.size(); } std::string VirtualMachineImpl::_GetFunctionParamName(std::string func_name, int index) { const VMFuncInfo& vm_func = LookupVMFuncInfo(func_name); if (static_cast<size_t>(index) >= vm_func.param_names.size()) { LOG(FATAL) << "ValueError: Invalid index for " << func_name << " (" << index << " out of " << vm_func.param_names.size() << ")"; } return vm_func.param_names[index]; } PackedFunc VirtualMachineImpl::_LookupFunction(const String& name) { if (Optional<VMClosure> opt = this->GetClosureInternal(name, true)) { return PackedFunc( [clo = opt.value(), _self = GetRef<Module>(this)](TVMArgs args, TVMRetValue* rv) -> void { auto* self = const_cast<VirtualMachineImpl*>(_self.as<VirtualMachineImpl>()); ICHECK(self); self->InvokeClosurePacked(clo, args, rv); }); } return PackedFunc(nullptr); } //---------------------------------------------------------------- // Profiler can be optionally disabled via a macro to reduce dep. //---------------------------------------------------------------- #if TVM_RELAX_VM_ENABLE_PROFILER /*! * \brief An extension of VirtualMachineImpl to support per-op profiling * It overrides RunInstrCall to add instrumentations around it. 
*/ class VirtualMachineProfiler : public VirtualMachineImpl { public: PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) override { if (name == "profile") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { std::string f_name = args[0].cast<std::string>(); VMClosure clo = this->GetClosure(f_name); std::vector<Device> devices; for (auto dev : this->devices) { if (dev.device_type > 0) { devices.push_back(dev); } } prof_ = profiling::Profiler(devices, {}, {{String("Executor"), String("VM")}}); auto inputs = GetInputsFor(f_name); bool clear_inputs = false; if (inputs.size() == 0) { ICHECK(args.size() > 1) << "No input is provided"; SetInput(f_name, false, args.Slice(1)); inputs = GetInputsFor(f_name); clear_inputs = true; } else { ICHECK_EQ(args.size(), 1) << "Inputs are already provided by set_input."; } // warmup this->InvokeClosureInternal(clo, inputs); prof_->Start(); this->InvokeClosureInternal(clo, inputs); prof_->Stop(); // Return the report as json, since profiling::Report object is not supported by RPC std::string report_json = prof_->Report()->AsJSON(); *rv = report_json; prof_ = std::nullopt; // releases hardware counters if (clear_inputs) { // SetInput modifies the internal states of VM. Undo the change after profiling. 
ClearInputsFor(f_name); } }); } else { return VirtualMachineImpl::GetFunction(name, sptr_to_self); } } protected: void RunInstrCall(VMFrame* curr_frame, Instruction inst) override { bool profiling = false; if (prof_ && prof_->IsRunning()) { auto f_name = GetFuncName(inst.func_idx); std::optional<Device> dev; std::vector<NDArray> arrs; auto f_check_ndarray_arg = [&dev, &arrs](const RegType& arg) { if (auto opt_nd = arg.as<NDArray>()) { NDArray arr = opt_nd.value(); if (arr.defined()) { dev = arr->device; arrs.push_back(arr); } } }; for (Index i = 0; i < inst.num_args; ++i) { Instruction::Arg arg = inst.args[i]; if (arg.kind() == Instruction::ArgKind::kRegister) { auto reg = ReadRegister(curr_frame, arg.value()); f_check_ndarray_arg(reg); } else if (arg.kind() == Instruction::ArgKind::kConstIdx) { const auto& const_val = this->const_pool_[arg.value()]; f_check_ndarray_arg(const_val); } } std::unordered_map<std::string, ObjectRef> metrics; metrics["Argument Shapes"] = profiling::ShapeString(arrs); // If a suitable device is found, enable profiling. if (dev) { profiling = true; prof_->StartCall(f_name, *dev, metrics); } } VirtualMachineImpl::RunInstrCall(curr_frame, inst); if (profiling) { prof_->StopCall(); } } private: std::optional<profiling::Profiler> prof_; }; ObjectPtr<VirtualMachine> VirtualMachine::CreateProfiler() { return make_object<VirtualMachineProfiler>(); } #else ObjectPtr<VirtualMachine> VirtualMachine::CreateProfiler() { LOG(FATAL) << "Profiler support is disabled"; return nullptr; } #endif // TVM_RELAX_VM_ENABLE_PROFILER } // namespace relax_vm } // namespace runtime } // namespace tvm