in onnxruntime/python/onnxruntime_pybind_state.cc [998:1504]
void addObjectMethods(py::module& m, Environment& env, ExecutionProviderRegistrationFn ep_registration_fn) {
py::enum_<GraphOptimizationLevel>(m, "GraphOptimizationLevel")
.value("ORT_DISABLE_ALL", GraphOptimizationLevel::ORT_DISABLE_ALL)
.value("ORT_ENABLE_BASIC", GraphOptimizationLevel::ORT_ENABLE_BASIC)
.value("ORT_ENABLE_EXTENDED", GraphOptimizationLevel::ORT_ENABLE_EXTENDED)
.value("ORT_ENABLE_ALL", GraphOptimizationLevel::ORT_ENABLE_ALL);
py::enum_<ExecutionMode>(m, "ExecutionMode")
.value("ORT_SEQUENTIAL", ExecutionMode::ORT_SEQUENTIAL)
.value("ORT_PARALLEL", ExecutionMode::ORT_PARALLEL);
py::enum_<ExecutionOrder>(m, "ExecutionOrder")
.value("DEFAULT", ExecutionOrder::DEFAULT)
.value("PRIORITY_BASED", ExecutionOrder::PRIORITY_BASED);
py::enum_<OrtAllocatorType>(m, "OrtAllocatorType")
.value("INVALID", OrtInvalidAllocator)
.value("ORT_DEVICE_ALLOCATOR", OrtDeviceAllocator)
.value("ORT_ARENA_ALLOCATOR", OrtArenaAllocator);
py::enum_<OrtMemType>(m, "OrtMemType")
.value("CPU_INPUT", OrtMemTypeCPUInput)
.value("CPU_OUTPUT", OrtMemTypeCPUOutput)
.value("CPU", OrtMemTypeCPU)
.value("DEFAULT", OrtMemTypeDefault);
py::class_<OrtDevice> device(m, "OrtDevice", R"pbdoc(ONNX Runtime device information.)pbdoc");
device.def(py::init<OrtDevice::DeviceType, OrtDevice::MemoryType, OrtDevice::DeviceId>())
.def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc")
.def("device_type", &OrtDevice::Type, R"pbdoc(Device Type.)pbdoc")
.def_static("cpu", []() { return OrtDevice::CPU; })
.def_static("cuda", []() { return OrtDevice::GPU; })
.def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; });
py::class_<OrtArenaCfg> ort_arena_cfg_binding(m, "OrtArenaCfg");
// There is a global var: arena_extend_strategy, which means we can't use that var name here
// See docs/C_API.md for details on what the following parameters mean and how to choose these values
ort_arena_cfg_binding.def(py::init([](size_t max_mem, int arena_extend_strategy_local,
int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
auto ort_arena_cfg = std::make_unique<OrtArenaCfg>();
ort_arena_cfg->max_mem = max_mem;
ort_arena_cfg->arena_extend_strategy = arena_extend_strategy_local;
ort_arena_cfg->initial_chunk_size_bytes = initial_chunk_size_bytes;
ort_arena_cfg->max_dead_bytes_per_chunk = max_dead_bytes_per_chunk;
return ort_arena_cfg;
}));
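// Illustrative Python construction of an OrtArenaCfg (a sketch; the parameter
// order mirrors the init above, and the 0/-1 values are assumed to let ORT
// pick defaults, per docs/C_API.md):
//
//   arena_cfg = C.OrtArenaCfg(0, -1, -1, -1)
//   # args: max_mem, arena_extend_strategy, initial_chunk_size_bytes,
//   #       max_dead_bytes_per_chunk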
py::class_<OrtMemoryInfo> ort_memory_info_binding(m, "OrtMemoryInfo");
ort_memory_info_binding.def(py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
if (strcmp(name, onnxruntime::CPU) == 0) {
return std::make_unique<OrtMemoryInfo>(onnxruntime::CPU, type, OrtDevice(), id, mem_type);
} else if (strcmp(name, onnxruntime::CUDA) == 0) {
return std::make_unique<OrtMemoryInfo>(
onnxruntime::CUDA, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id)), id,
mem_type);
} else if (strcmp(name, onnxruntime::CUDA_PINNED) == 0) {
return std::make_unique<OrtMemoryInfo>(
onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast<OrtDevice::DeviceId>(id)),
id, mem_type);
} else {
throw std::runtime_error("Specified device is not supported.");
}
}));
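// Illustrative Python construction of an OrtMemoryInfo (a sketch). The name
// must match one of the constants checked above; "Cpu", "Cuda" and "CudaPinned"
// are assumed to be the values of onnxruntime::CPU / CUDA / CUDA_PINNED:
//
//   mem_info = C.OrtMemoryInfo("Cpu", C.OrtAllocatorType.ORT_ARENA_ALLOCATOR,
//                              0, C.OrtMemType.DEFAULT)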
py::class_<PySessionOptions>
sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc");
sess
.def(py::init())
.def_readwrite("enable_cpu_mem_arena", &PySessionOptions::enable_cpu_mem_arena,
R"pbdoc(Enables the memory arena on CPU. Arena may pre-allocate memory for future usage.
Set this option to false if you don't want it. Default is True.)pbdoc")
.def_readwrite("enable_profiling", &PySessionOptions::enable_profiling,
R"pbdoc(Enable profiling for this session. Default is false.)pbdoc")
.def_readwrite("profile_file_prefix", &PySessionOptions::profile_file_prefix,
R"pbdoc(The prefix of the profile file. The current time will be appended to the file name.)pbdoc")
.def_readwrite("optimized_model_filepath", &PySessionOptions::optimized_model_filepath,
R"pbdoc(
File path to serialize optimized model to.
Optimized model is not serialized unless optimized_model_filepath is set.
Serialized model format will default to ONNX unless:
- add_session_config_entry is used to set 'session.save_model_format' to 'ORT', or
- there is no 'session.save_model_format' config entry and optimized_model_filepath ends in '.ort' (case insensitive)
)pbdoc")
.def_readwrite("enable_mem_pattern", &PySessionOptions::enable_mem_pattern,
R"pbdoc(Enable the memory pattern optimization. Default is true.)pbdoc")
.def_readwrite("enable_mem_reuse", &PySessionOptions::enable_mem_reuse,
R"pbdoc(Enable the memory reuse optimization. Default is true.)pbdoc")
.def_readwrite("logid", &PySessionOptions::session_logid,
R"pbdoc(Logger id to use for session output.)pbdoc")
.def_readwrite("log_severity_level", &PySessionOptions::session_log_severity_level,
R"pbdoc(Log severity level. Applies to session load, initialization, etc.
0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal. Default is 2.)pbdoc")
.def_readwrite("log_verbosity_level", &PySessionOptions::session_log_verbosity_level,
R"pbdoc(VLOG level if DEBUG build and session_log_severity_level is 0.
Applies to session load, initialization, etc. Default is 0.)pbdoc")
.def_property(
"intra_op_num_threads",
[](const PySessionOptions* options) -> int { return options->intra_op_param.thread_pool_size; },
[](PySessionOptions* options, int value) -> void { options->intra_op_param.thread_pool_size = value; },
R"pbdoc(Sets the number of threads used to parallelize the execution within nodes. Default is 0 to let onnxruntime choose.)pbdoc")
.def_property(
"inter_op_num_threads",
[](const PySessionOptions* options) -> int { return options->inter_op_param.thread_pool_size; },
[](PySessionOptions* options, int value) -> void { options->inter_op_param.thread_pool_size = value; },
R"pbdoc(Sets the number of threads used to parallelize the execution of the graph (across nodes). Default is 0 to let onnxruntime choose.)pbdoc")
.def_readwrite("execution_mode", &PySessionOptions::execution_mode,
R"pbdoc(Sets the execution mode. Default is sequential.)pbdoc")
.def_readwrite("execution_order", &PySessionOptions::execution_order,
R"pbdoc(Sets the execution order. Default is basic topological order.)pbdoc")
.def_property(
"graph_optimization_level",
[](const PySessionOptions* options) -> GraphOptimizationLevel {
GraphOptimizationLevel retval = ORT_ENABLE_ALL;
switch (options->graph_optimization_level) {
case onnxruntime::TransformerLevel::Default:
retval = ORT_DISABLE_ALL;
break;
case onnxruntime::TransformerLevel::Level1:
retval = ORT_ENABLE_BASIC;
break;
case onnxruntime::TransformerLevel::Level2:
retval = ORT_ENABLE_EXTENDED;
break;
case onnxruntime::TransformerLevel::Level3:
retval = ORT_ENABLE_ALL;
break;
default:
retval = ORT_ENABLE_ALL;
LOGS_DEFAULT(WARNING) << "Got invalid graph optimization level; defaulting to ORT_ENABLE_ALL";
break;
}
return retval;
},
[](PySessionOptions* options, GraphOptimizationLevel level) -> void {
switch (level) {
case ORT_DISABLE_ALL:
options->graph_optimization_level = onnxruntime::TransformerLevel::Default;
break;
case ORT_ENABLE_BASIC:
options->graph_optimization_level = onnxruntime::TransformerLevel::Level1;
break;
case ORT_ENABLE_EXTENDED:
options->graph_optimization_level = onnxruntime::TransformerLevel::Level2;
break;
case ORT_ENABLE_ALL:
options->graph_optimization_level = onnxruntime::TransformerLevel::Level3;
break;
}
},
R"pbdoc(Graph optimization level for this session.)pbdoc")
.def_readwrite("use_deterministic_compute", &PySessionOptions::use_deterministic_compute,
R"pbdoc(Whether to use deterministic compute. Default is false.)pbdoc")
.def(
"add_free_dimension_override_by_denotation",
[](PySessionOptions* options, const char* dim_name, int64_t dim_value)
-> void { options->free_dimension_overrides.push_back(
onnxruntime::FreeDimensionOverride{
dim_name,
onnxruntime::FreeDimensionOverrideType::Denotation,
dim_value}); },
R"pbdoc(Specify the dimension size for each denotation associated with an input's free dimension.)pbdoc")
.def(
"add_free_dimension_override_by_name",
[](PySessionOptions* options, const char* dim_name, int64_t dim_value)
-> void { options->free_dimension_overrides.push_back(
onnxruntime::FreeDimensionOverride{
dim_name,
onnxruntime::FreeDimensionOverrideType::Name,
dim_value}); },
R"pbdoc(Specify values of named dimensions within model inputs.)pbdoc")
.def(
"add_session_config_entry",
[](PySessionOptions* options, const char* config_key, const char* config_value) -> void {
// config_key and config_value will be copied
const Status status = options->config_options.AddConfigEntry(config_key, config_value);
if (!status.IsOK())
throw std::runtime_error(status.ErrorMessage());
},
R"pbdoc(Set a single session configuration entry as a pair of strings.)pbdoc")
.def(
"get_session_config_entry",
[](const PySessionOptions* options, const char* config_key) -> std::string {
const std::string key(config_key);
std::string value;
if (!options->config_options.TryGetConfigEntry(key, value))
throw std::runtime_error("SessionOptions does not have configuration with key: " + key);
return value;
},
R"pbdoc(Get a single session configuration value using the given configuration key.)pbdoc")
.def(
"register_custom_ops_library",
[](PySessionOptions* options, const char* library_path) -> void {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
// We need to pass in an `OrtSessionOptions` instance because the exported method in the shared library expects that.
// Once we have access to the `OrtCustomOpDomains` within the passed-in `OrtSessionOptions` instance, we place it
// into the container we are maintaining for that very purpose, and the `OrtSessionOptions` instance can go out of scope.
OrtSessionOptions s;
options->custom_op_libraries_.emplace_back(std::make_shared<CustomOpLibrary>(library_path, s));
// reserve enough memory to hold current contents and the new incoming contents
options->custom_op_domains_.reserve(options->custom_op_domains_.size() + s.custom_op_domains_.size());
for (size_t i = 0; i < s.custom_op_domains_.size(); ++i) {
options->custom_op_domains_.emplace_back(s.custom_op_domains_[i]);
}
#else
ORT_UNUSED_PARAMETER(options);
ORT_UNUSED_PARAMETER(library_path);
ORT_THROW("Custom Ops are not supported in this build.");
#endif
},
R"pbdoc(Specify the path to the shared library containing the custom op kernels required to run a model.)pbdoc")
.def(
"add_initializer", [](PySessionOptions* options, const char* name, py::object& ml_value_pyobject) -> void {
ORT_ENFORCE(strcmp(Py_TYPE(ml_value_pyobject.ptr())->tp_name, PYTHON_ORTVALUE_OBJECT_NAME) == 0, "The provided Python object must be an OrtValue");
// The user needs to ensure that the python OrtValue being provided as an overriding initializer
// is not destructed as long as any session that uses the provided OrtValue initializer is still in scope
// This is no different than the native APIs
const OrtValue* ml_value = ml_value_pyobject.attr(PYTHON_ORTVALUE_NATIVE_OBJECT_ATTR).cast<OrtValue*>();
ORT_THROW_IF_ERROR(options->AddInitializer(name, ml_value));
});
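// Illustrative Python usage of a few of the SessionOptions bindings above
// (a sketch, assuming the top-level `onnxruntime` package re-exports these
// names, as the released wheels do):
//
//   import onnxruntime as ort
//   so = ort.SessionOptions()
//   so.intra_op_num_threads = 4
//   so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
//   so.add_session_config_entry("session.save_model_format", "ORT")
//   assert so.get_session_config_entry("session.save_model_format") == "ORT"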
py::class_<RunOptions>(m, "RunOptions", R"pbdoc(Configuration information for a single Run.)pbdoc")
.def(py::init())
.def_readwrite("log_severity_level", &RunOptions::run_log_severity_level,
R"pbdoc(Log severity level for a particular Run() invocation. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.)pbdoc")
.def_readwrite("log_verbosity_level", &RunOptions::run_log_verbosity_level,
R"pbdoc(VLOG level if DEBUG build and run_log_severity_level is 0.
Applies to a particular Run() invocation. Default is 0.)pbdoc")
.def_readwrite("logid", &RunOptions::run_tag,
"To identify logs generated by a particular Run() invocation.")
.def_readwrite("terminate", &RunOptions::terminate,
R"pbdoc(Set to True to terminate any currently executing calls that are using this
RunOptions instance. The individual calls will exit gracefully and return an error status.)pbdoc")
#ifdef ENABLE_TRAINING
.def_readwrite("training_mode", &RunOptions::training_mode,
R"pbdoc(Choose to run in training or inferencing mode)pbdoc")
#endif
.def_readwrite("only_execute_path_to_fetches", &RunOptions::only_execute_path_to_fetches,
R"pbdoc(Only execute the nodes needed by fetch list)pbdoc");
py::class_<ModelMetadata>(m, "ModelMetadata", R"pbdoc(Pre-defined and custom metadata about the model.
It is usually used to identify the model used to run the prediction and
to facilitate comparison between models.)pbdoc")
.def_readwrite("producer_name", &ModelMetadata::producer_name, "producer name")
.def_readwrite("graph_name", &ModelMetadata::graph_name, "graph name")
.def_readwrite("domain", &ModelMetadata::domain, "ONNX domain")
.def_readwrite("description", &ModelMetadata::description, "description of the model")
.def_readwrite("graph_description", &ModelMetadata::graph_description, "description of the graph hosted in the model")
.def_readwrite("version", &ModelMetadata::version, "version of the model")
.def_readwrite("custom_metadata_map", &ModelMetadata::custom_metadata_map, "additional metadata");
py::class_<onnxruntime::NodeArg>(m, "NodeArg", R"pbdoc(Node argument definition, for both input and output,
including arg name, arg type (contains both type and shape).)pbdoc")
.def_property_readonly("name", &onnxruntime::NodeArg::Name, "node name")
.def_property_readonly(
"type", [](const onnxruntime::NodeArg& na) -> std::string {
return *(na.Type());
},
"node type")
.def(
"__str__", [](const onnxruntime::NodeArg& na) -> std::string {
std::ostringstream res;
res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape=";
auto shape = na.Shape();
std::vector<py::object> arr;
if (shape == nullptr || shape->dim_size() == 0) {
res << "[]";
} else {
res << "[";
for (int i = 0; i < shape->dim_size(); ++i) {
if (utils::HasDimValue(shape->dim(i))) {
res << shape->dim(i).dim_value();
} else if (utils::HasDimParam(shape->dim(i))) {
res << "'" << shape->dim(i).dim_param() << "'";
} else {
res << "None";
}
if (i < shape->dim_size() - 1) {
res << ", ";
}
}
res << "]";
}
res << ")";
return std::string(res.str());
},
"converts the node into a readable string")
.def_property_readonly(
"shape", [](const onnxruntime::NodeArg& na) -> std::vector<py::object> {
auto shape = na.Shape();
std::vector<py::object> arr;
if (shape == nullptr || shape->dim_size() == 0) {
return arr;
}
arr.resize(shape->dim_size());
for (int i = 0; i < shape->dim_size(); ++i) {
if (utils::HasDimValue(shape->dim(i))) {
arr[i] = py::cast(shape->dim(i).dim_value());
} else if (utils::HasDimParam(shape->dim(i))) {
arr[i] = py::cast(shape->dim(i).dim_param());
} else {
arr[i] = py::none();
}
}
return arr;
},
"node shape (assuming the node holds a tensor)");
py::class_<SessionObjectInitializer> sessionObjectInitializer(m, "SessionObjectInitializer");
py::class_<PyInferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
// In Python3, a Python bytes object will be passed to C++ functions that accept std::string or char*
// without any conversion. So this init method can be used for model file path (string) and model content (bytes)
.def(py::init([&env](const PySessionOptions& so, const std::string arg, bool is_arg_file_name,
bool load_config_from_model = false) {
std::unique_ptr<PyInferenceSession> sess;
// separate creation of the session from model loading unless we have to read the config from the model.
// in a minimal build we only support load via Load(...) and not at session creation time
if (load_config_from_model) {
#if !defined(ORT_MINIMAL_BUILD)
sess = std::make_unique<PyInferenceSession>(env, so, arg, is_arg_file_name);
RegisterCustomOpDomainsAndLibraries(sess.get(), so);
OrtPybindThrowIfError(sess->GetSessionHandle()->Load());
#else
ORT_THROW("Loading configuration from an ONNX model is not supported in this build.");
#endif
} else {
sess = std::make_unique<PyInferenceSession>(env, so);
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
RegisterCustomOpDomainsAndLibraries(sess.get(), so);
#endif
if (is_arg_file_name) {
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(arg));
} else {
OrtPybindThrowIfError(sess->GetSessionHandle()->Load(arg.data(), arg.size()));
}
}
return sess;
}))
.def(
"initialize_session",
[ep_registration_fn](PyInferenceSession* sess,
const std::vector<std::string>& provider_types = {},
const ProviderOptionsVector& provider_options = {},
const std::unordered_set<std::string>& disabled_optimizer_names = {}) {
InitializeSession(sess->GetSessionHandle(),
ep_registration_fn,
provider_types,
provider_options,
disabled_optimizer_names);
},
R"pbdoc(Load a model saved in ONNX or ORT format.)pbdoc")
.def("run",
[](PyInferenceSession* sess, std::vector<std::string> output_names,
std::map<std::string, py::object> pyfeeds, RunOptions* run_options = nullptr)
-> std::vector<py::object> {
NameMLValMap feeds;
for (auto feed : pyfeeds) {
// No need to process 'None's sent in by the user
// to feed Optional inputs in the graph.
// We just won't include anything in the feed and ORT
// will handle such implicit 'None's internally.
if (!feed.second.is(py::none())) {
OrtValue ml_value;
auto px = sess->GetSessionHandle()->GetModelInputs();
if (!px.first.IsOK() || !px.second) {
throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
}
CreateGenericMLValue(px.second, GetAllocator(), feed.first, feed.second, &ml_value);
ThrowIfPyErrOccured();
feeds.insert(std::make_pair(feed.first, ml_value));
}
}
std::vector<OrtValue> fetches;
common::Status status;
{
// release GIL to allow multiple python threads to invoke Run() in parallel.
py::gil_scoped_release release;
if (run_options != nullptr) {
OrtPybindThrowIfError(sess->GetSessionHandle()->Run(*run_options, feeds, output_names, &fetches));
} else {
OrtPybindThrowIfError(sess->GetSessionHandle()->Run(feeds, output_names, &fetches));
}
}
std::vector<py::object> rfetch;
rfetch.reserve(fetches.size());
size_t pos = 0;
for (auto fet : fetches) {
if (fet.IsAllocated()) {
if (fet.IsTensor()) {
rfetch.push_back(AddTensorAsPyObj(fet, nullptr, nullptr));
} else if (fet.IsSparseTensor()) {
rfetch.push_back(GetPyObjectFromSparseTensor(pos, fet, nullptr));
} else {
rfetch.push_back(AddNonTensorAsPyObj(fet, nullptr, nullptr));
}
} else { // Send back None because the corresponding OrtValue was empty
rfetch.push_back(py::none());
}
++pos;
}
return rfetch;
})
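// Illustrative call that reaches the `run` binding above through the public
// Python API (a sketch; input/output names depend on the model):
//
//   import numpy as np
//   x = np.zeros((1, 3), dtype=np.float32)
//   outputs = sess.run(["output"], {"input": x})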
/// This method accepts a dictionary of feeds (name -> OrtValue) and the list of output_names
/// and returns a list of python objects representing OrtValues. Each name may represent either
/// a Tensor, SparseTensor or a TensorSequence.
.def("run_with_ort_values", [](PyInferenceSession* sess, const py::dict& feeds, const std::vector<std::string>& output_names, RunOptions* run_options = nullptr) -> std::vector<OrtValue> {
NameMLValMap ort_feeds;
// `item` is always a copy since iterating a py::dict yields values, not
// references, and the Apple Xcode toolchain warns otherwise
for (const auto item : feeds) {
auto name = item.first.cast<std::string>();
const OrtValue* ort_value = item.second.cast<const OrtValue*>();
ort_feeds.emplace(name, *ort_value);
}
std::vector<OrtValue> fetches;
{
// release GIL to allow multiple python threads to invoke Run() in parallel.
py::gil_scoped_release release;
if (run_options != nullptr) {
OrtPybindThrowIfError(sess->GetSessionHandle()->Run(*run_options, ort_feeds, output_names, &fetches));
} else {
OrtPybindThrowIfError(sess->GetSessionHandle()->Run(ort_feeds, output_names, &fetches));
}
}
return fetches;
})
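// Illustrative use with OrtValue feeds (a sketch; the public wrapper's
// argument order may differ from this raw binding's):
//
//   ov = ort.OrtValue.ortvalue_from_numpy(x)
//   results = sess.run_with_ort_values(["output"], {"input": ov})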
.def("end_profiling", [](const PyInferenceSession* sess) -> std::string {
return sess->GetSessionHandle()->EndProfiling();
})
.def_property_readonly("get_profiling_start_time_ns", [](const PyInferenceSession* sess) -> uint64_t {
return sess->GetSessionHandle()->GetProfiling().GetStartTimeNs();
})
.def(
"get_providers", [](const PyInferenceSession* sess) -> const std::vector<std::string>& {
return sess->GetSessionHandle()->GetRegisteredProviderTypes();
},
py::return_value_policy::reference_internal)
.def(
"get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& {
return sess->GetSessionHandle()->GetAllProviderOptions();
},
py::return_value_policy::reference_internal)
.def_property_readonly(
"session_options", [](const PyInferenceSession* sess) -> const PySessionOptions& {
const auto& session_options = sess->GetSessionHandle()->GetSessionOptions();
return static_cast<const PySessionOptions&>(session_options);
},
py::return_value_policy::reference_internal)
.def_property_readonly(
"inputs_meta", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
auto res = sess->GetSessionHandle()->GetModelInputs();
OrtPybindThrowIfError(res.first);
return *(res.second);
},
py::return_value_policy::reference_internal)
.def_property_readonly(
"outputs_meta", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
auto res = sess->GetSessionHandle()->GetModelOutputs();
OrtPybindThrowIfError(res.first);
return *(res.second);
},
py::return_value_policy::reference_internal)
.def_property_readonly(
"overridable_initializers", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
auto res = sess->GetSessionHandle()->GetOverridableInitializers();
OrtPybindThrowIfError(res.first);
return *(res.second);
},
py::return_value_policy::reference_internal)
.def_property_readonly(
"model_meta", [](const PyInferenceSession* sess) -> const onnxruntime::ModelMetadata& {
auto res = sess->GetSessionHandle()->GetModelMetadata();
OrtPybindThrowIfError(res.first);
return *(res.second);
},
py::return_value_policy::reference_internal)
.def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void {
Status status;
if (!run_options)
status = sess->GetSessionHandle()->Run(*io_binding.Get());
else
status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get());
if (!status.IsOK())
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
});
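// Illustrative I/O binding flow via the public Python API (a sketch):
//
//   io = sess.io_binding()
//   io.bind_cpu_input("input", x)
//   io.bind_output("output")
//   sess.run_with_iobinding(io)
//   result = io.copy_outputs_to_cpu()[0]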
py::enum_<onnxruntime::ArenaExtendStrategy>(m, "ArenaExtendStrategy", py::arithmetic())
.value("kNextPowerOfTwo", onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo)
.value("kSameAsRequested", onnxruntime::ArenaExtendStrategy::kSameAsRequested)
.export_values();
}