include/tvm/ir/module.h (94 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ir/module.h
* \brief IRModule that holds the functions and type definitions.
*/
#ifndef TVM_IR_MODULE_H_
#define TVM_IR_MODULE_H_
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
#include <tvm/ir/global_info.h>
#include <tvm/ir/source_map.h>
#include <tvm/ir/type.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/string.h>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace tvm {
class IRModule;
/*!
 * \brief IRModule that holds functions and type definitions.
 *
 * IRModule is the basic unit for all IR transformations across the stack.
 *
 * Many operations require access to the global IRModule.
 * We pass the IRModule by value in a functional style as an explicit argument,
 * but we mutate the Module while optimizing programs.
 * \sa IRModule
 */
class IRModuleNode : public Object {
 public:
  /*! \brief A map from global variables to all global functions in the module. */
  Map<GlobalVar, BaseFunc> functions;
  /*! \brief The source map for the module. */
  SourceMap source_map;
  /*! \brief Additional attributes storing meta-data about the module. */
  DictAttrs attrs;
  /*!
   * \brief Globally static objects that are referred to by the IR itself,
   *  keyed by name (see UpdateGlobalInfo).
   */
  Map<String, Array<GlobalInfo>> global_infos;
  /*!
   * \brief A map from string names to global variables that
   *  ensures global uniqueness.
   *
   * \note ContainGlobalVar queries this map.
   */
  Map<String, GlobalVar> global_var_map_;
  /*!
   * \brief Get a module attribute.
   *
   * \param attr_key The attribute key.
   * \param default_value The value returned if the key does not exist; defaults to an
   *  empty Optional (std::nullopt).
   *
   * \return The attribute value, or \p default_value if the key is absent.
   *
   * \tparam TObjectRef the expected object type.
   * \throw Error if the key exists but the value does not match TObjectRef
   *
   * \code
   *
   *  void GetAttrExample(const IRModule& mod) {
   *    auto value = mod->GetAttr<Integer>("AttrKey", 0);
   *  }
   *
   * \endcode
   */
  template <typename TObjectRef>
  Optional<TObjectRef> GetAttr(
      const std::string& attr_key,
      Optional<TObjectRef> default_value = Optional<TObjectRef>(std::nullopt)) const {
    // Delegates the lookup to the module's DictAttrs.
    return attrs.GetAttr(attr_key, default_value);
  }
  // Overload taking a plain TObjectRef default, so callers can pass a value that converts
  // implicitly to TObjectRef (e.g. a literal). It wraps the default in an Optional and
  // forwards to the Optional-taking overload.
  template <typename TObjectRef>
  Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
    return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
  }
  /*!
   * \brief Get the metadata attributes.
   * \returns The additional meta-data attributes
   */
  DictAttrs GetAttrs() const { return attrs; }
  /*!
   * \brief Check whether the module has a non-zero integer attr.
   *
   * This function can be used to check whether an optional
   * attribute mark (e.g. inline) exists.
   *
   * \param attr_key The key to the attribute.
   * \return The check result.
   *
   * \code
   *
   *  void HasNonzeroAttrExample(const IRModule& mod) {
   *    if (mod->HasNonzeroAttr(attr::kInline)) {
   *      // inline the function.
   *    }
   *  }
   *
   * \endcode
   */
  bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); }
  /*! \brief Default constructor; explicitly value-initializes the source map. */
  IRModuleNode() : source_map() {}
  /*!
   * \brief Reflection hook: passes each field, under its field name, to \p v.
   * \param v The attribute visitor.
   */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("functions", &functions);
    v->Visit("global_var_map_", &global_var_map_);
    v->Visit("source_map", &source_map);
    v->Visit("attrs", &attrs);
    v->Visit("global_infos", &global_infos);
  }
  /*! \brief Structural-equality hook (see also _type_has_method_sequal_reduce). */
  TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
  /*! \brief Structural-hash hook (see also _type_has_method_shash_reduce). */
  TVM_DLL void SHashReduce(SHashReducer hash_reduce) const;
  /*!
   * \brief Add a function to the global environment.
   * \param var The var of the global function.
   * \param func The function.
   * \param update Controls whether an existing definition for \p var may be
   *  replaced in the environment.
   */
  TVM_DLL void Add(const GlobalVar& var, const BaseFunc& func, bool update = false);
  /*!
   * \brief Add a function to the global environment.
   * \param var The global variable of the function.
   * \param func The function.
   *
   * Unlike Add, it does not perform type inference.
   */
  TVM_DLL void AddUnchecked(const GlobalVar& var, const BaseFunc& func);
  /*!
   * \brief Update a function in the global environment.
   * \param var The global variable of the function to update.
   * \param func The new function.
   */
  TVM_DLL void Update(const GlobalVar& var, const BaseFunc& func);
  /*!
   * \brief Update an array of global infos in the global environment.
   * \param name The name of the global info.
   * \param info The new array of global infos.
   */
  TVM_DLL void UpdateGlobalInfo(const String& name, const Array<GlobalInfo>& info);
  /*!
   * \brief Remove a function from the global environment.
   * \param var The global variable of the function to remove.
   */
  TVM_DLL void Remove(const GlobalVar& var);
  /*!
   * \brief Check if global_var_map_ contains a global variable with the given name.
   * \param name The variable name.
   * \returns true if it does, otherwise false.
   */
  TVM_DLL bool ContainGlobalVar(const String& name) const;
  /*!
   * \brief Look up a global variable by its name.
   * \param str The unique string name of the global variable.
   * \returns The global variable.
   */
  TVM_DLL GlobalVar GetGlobalVar(const String& str) const;
  /*!
   * \brief Collect all global vars defined in this module, ordered by
   *  the global variable name.
   * \returns An array of global vars
   */
  TVM_DLL Array<GlobalVar> GetGlobalVars() const;
  /*!
   * \brief Look up a global function by its variable.
   * \param var The global var to lookup.
   * \returns The function named by the variable argument.
   */
  TVM_DLL BaseFunc Lookup(const GlobalVar& var) const;
  /*!
   * \brief Look up a global function by its string name
   * \param name The name of the function.
   * \returns The function named by the argument.
   */
  TVM_DLL BaseFunc Lookup(const String& name) const;
  /*!
   * \brief Update the functions in this module with the functions
   *  of another module.
   * \param other The other module.
   */
  TVM_DLL void Update(const IRModule& other);
  /*!
   * \brief Create a shallow copy of this IRModule.
   * \returns The shallow copy of the IRModule.
   */
  TVM_DLL IRModule ShallowCopy();
  /*!
   * \brief The set of imported files.
   */
  TVM_DLL std::unordered_set<String> Imports() const;

  TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
  // Type key, plus flags advertising the SEqualReduce/SHashReduce methods declared above.
  static constexpr const char* _type_key = "IRModule";
  static constexpr const bool _type_has_method_sequal_reduce = true;
  static constexpr const bool _type_has_method_shash_reduce = true;
  TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);

 private:
  // Grants the IRModule reference class access to private members.
  friend class IRModule;
};
/*!
 * \brief Managed reference class to IRModuleNode.
 * \sa IRModuleNode
 */
class IRModule : public ObjectRef {
 public:
  /*!
   * \brief constructor
   * \param functions Functions in the module.
   * \param map The module source map.
   * \param attrs The module meta-data attributes.
   * \param global_infos Global infos in the module.
   */
  TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions, SourceMap map = {},
                            DictAttrs attrs = DictAttrs(),
                            Map<String, Array<GlobalInfo>> global_infos = {});
  /*!
   * \brief Default constructor: a module with an empty function map and
   *  default values for all other constructor arguments.
   */
  IRModule() : IRModule(Map<GlobalVar, BaseFunc>({})) {}
  /*!
   * \brief Construct a reference from an existing object pointer.
   * \param n The object pointer. It is taken by value and moved into the base,
   *  which avoids a redundant reference-count increment/decrement pair.
   */
  explicit IRModule(ObjectPtr<Object> n) : ObjectRef(std::move(n)) {}
  /*!
   * \return A mutable pointer to the underlying node, even through a const reference.
   * \note Fails an ICHECK if this reference is null.
   */
  IRModuleNode* operator->() const {
    auto* ptr = get_mutable();
    ICHECK(ptr != nullptr);
    return static_cast<IRModuleNode*>(ptr);
  }
  /*!
   * \brief Construct a module from a standalone expression.
   *
   * The expression is bound to the global function 'main', and the resulting
   * module has no imports.
   * \param expr The expression to wrap in a module.
   * \param global_funcs Global functions to place in the module alongside \p expr
   *  (presumably so \p expr can reference them).
   * \return The constructed module.
   */
  TVM_DLL static IRModule FromExpr(const RelaxExpr& expr,
                                   const Map<GlobalVar, BaseFunc>& global_funcs = {});
  /*!
   * \brief Create a shallow copy of an IRModule.
   * \param mod The module to copy.
   * \return The copied module.
   */
  IRModule ShallowCopyIRModule(IRModule mod);
  /*! \brief Declare the container type. */
  using ContainerType = IRModuleNode;
  // Allow copy-on-write access to the node.
  TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode);
};
namespace attr {
// Following are attributes for IRModule only.
//
// These are declared `inline constexpr` (C++17) so that each key is a single entity
// with one address across all translation units, instead of a separate
// internal-linkage copy per TU.

/*!
 * \brief Name of the module
 *
 * Type: String
 */
inline constexpr const char* kModuleName = "mod_name";

/*!
 * \brief All the runtime::NDArrays extracted from PrimFunc tir::AllocateConst nodes. The
 * node will record the index into this array. See also kConstNameToConstant below, which is
 * the analog for named constants bound by external codegen modules.
 *
 * Type: Array<runtime::NDArray>
 */
inline constexpr const char* kConstants = "constants";

/*!
 * \brief All the runtime::Modules accumulated during compilation by external codegen. These
 * modules must be either directly linked or captured in the final compilation artifact.
 *
 * Type: Array<runtime::Module>
 */
inline constexpr const char* kExternalMods = "external_mods";

/*!
 * \brief A prefix for generating C symbols when creating a system library.
 *
 * This prefix guides passes that create a global_symbol for internal functions
 * that may have C linkage (e.g. TIR functions and some BYOC functions). It also affects
 * the symbol of the fat bin blob during module export.
 *
 * This attribute is used to avoid symbol conflicts when we
 * generate and combine multiple system libs that get linked into one.
 *
 * Rationale: mechanisms like BYOC rely on the common global symbol,
 * and each external compiler also has its own mangling mechanism.
 * As a result, we cannot rely on setting a global_symbol first and renaming it later,
 * because the external compiler has already agreed on the name.
 *
 * system_lib_prefix provides a way to hint the passes so that names
 * avoid conflicts from the beginning.
 *
 * Note that users can still directly specify global symbols that may conflict.
 * It is up to the downstream toolchain to manage those external-facing functions.
 *
 * This does not affect non-C-linkage functions. They are less of an issue because
 * they are embedded into the fatbin under different symbols, and the system lib
 * loader can pick the right entries for a given prefix.
 *
 * Having this attribute implies system lib generation linkage.
 */
inline constexpr const char* kSystemLibPrefix = "system_lib_prefix";

/*!
 * \brief All the named runtime::NDArrays accumulated during compilation by external codegen.
 * Generally the associated runtime::Module will indicate it requires bindings for these names,
 * and during module initialization these bindings will be recovered from a ConstLoaderModule.
 * See also kConstants above, which is the analog for PrimFuncs.
 *
 * Type: Map<String, runtime::NDArray>
 */
inline constexpr const char* kConstNameToConstant = "const_name_to_constant";
}  // namespace attr
} // namespace tvm
#endif // TVM_IR_MODULE_H_